Mirror of https://github.com/nomic-ai/gpt4all.git

Commit d1c56b8b28 (parent 7aa0f779de)
Implement configurable context length (#1749)
@@ -714,8 +714,9 @@ Bert::~Bert() {
     bert_free(d_ptr->ctx);
 }
 
-bool Bert::loadModel(const std::string &modelPath)
+bool Bert::loadModel(const std::string &modelPath, int n_ctx)
 {
+    (void)n_ctx;
     d_ptr->ctx = bert_load_from_file(modelPath.c_str());
     d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
     d_ptr->modelLoaded = d_ptr->ctx != nullptr;
@@ -728,8 +729,10 @@ bool Bert::isModelLoaded() const
     return d_ptr->modelLoaded;
 }
 
-size_t Bert::requiredMem(const std::string &/*modelPath*/)
+size_t Bert::requiredMem(const std::string &modelPath, int n_ctx)
 {
+    (void)modelPath;
+    (void)n_ctx;
     return 0;
 }
 
@@ -18,9 +18,9 @@ public:
 
     bool supportsEmbedding() const override { return true; }
     bool supportsCompletion() const override { return true; }
-    bool loadModel(const std::string &modelPath) override;
+    bool loadModel(const std::string &modelPath, int n_ctx) override;
     bool isModelLoaded() const override;
-    size_t requiredMem(const std::string &modelPath) override;
+    size_t requiredMem(const std::string &modelPath, int n_ctx) override;
     size_t stateSize() const override;
     size_t saveState(uint8_t *dest) const override;
     size_t restoreState(const uint8_t *src) override;
@@ -676,7 +676,8 @@ GPTJ::GPTJ()
     d_ptr->modelLoaded = false;
 }
 
-size_t GPTJ::requiredMem(const std::string &modelPath) {
+size_t GPTJ::requiredMem(const std::string &modelPath, int n_ctx) {
+    (void)n_ctx;
     gptj_model dummy_model;
     gpt_vocab dummy_vocab;
     size_t mem_req;
@@ -684,7 +685,8 @@ size_t GPTJ::requiredMem(const std::string &modelPath) {
     return mem_req;
 }
 
-bool GPTJ::loadModel(const std::string &modelPath) {
+bool GPTJ::loadModel(const std::string &modelPath, int n_ctx) {
+    (void)n_ctx;
     std::mt19937 rng(time(NULL));
     d_ptr->rng = rng;
 
@@ -17,9 +17,9 @@ public:
 
     bool supportsEmbedding() const override { return false; }
     bool supportsCompletion() const override { return true; }
-    bool loadModel(const std::string &modelPath) override;
+    bool loadModel(const std::string &modelPath, int n_ctx) override;
     bool isModelLoaded() const override;
-    size_t requiredMem(const std::string &modelPath) override;
+    size_t requiredMem(const std::string &modelPath, int n_ctx) override;
     size_t stateSize() const override;
     size_t saveState(uint8_t *dest) const override;
     size_t restoreState(const uint8_t *src) override;
@@ -120,7 +120,8 @@ struct llama_file_hparams {
     enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16;
 };
 
-size_t LLamaModel::requiredMem(const std::string &modelPath) {
+size_t LLamaModel::requiredMem(const std::string &modelPath, int n_ctx) {
+    // TODO(cebtenzzre): update to GGUF
     auto fin = std::ifstream(modelPath, std::ios::binary);
     fin.seekg(0, std::ios_base::end);
     size_t filesize = fin.tellg();
@@ -137,40 +138,31 @@ size_t LLamaModel::requiredMem(const std::string &modelPath) {
     fin.read(reinterpret_cast<char*>(&hparams.n_layer), sizeof(hparams.n_layer));
     fin.read(reinterpret_cast<char*>(&hparams.n_rot), sizeof(hparams.n_rot));
     fin.read(reinterpret_cast<char*>(&hparams.ftype), sizeof(hparams.ftype));
-    const size_t n_ctx = 2048;
     const size_t kvcache_element_size = 2; // fp16
     const size_t est_kvcache_size = hparams.n_embd * hparams.n_layer * 2u * n_ctx * kvcache_element_size;
     return filesize + est_kvcache_size;
 }
 
-bool LLamaModel::loadModel(const std::string &modelPath)
+bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx)
 {
     gpt_params params;
 
-    // load the model
+    if (n_ctx < 8) {
+        std::cerr << "warning: minimum context size is 8, using minimum size.\n";
+        n_ctx = 8;
+    }
+
+    // -- load the model --
 
     d_ptr->model_params = llama_model_default_params();
 
     d_ptr->model_params.use_mmap = params.use_mmap;
 #if defined (__APPLE__)
     d_ptr->model_params.use_mlock = true;
 #else
     d_ptr->model_params.use_mlock = params.use_mlock;
 #endif
 
-    d_ptr->ctx_params = llama_context_default_params();
-
-    d_ptr->ctx_params.n_ctx = 2048;
-    d_ptr->ctx_params.seed = params.seed;
-    d_ptr->ctx_params.f16_kv = params.memory_f16;
-
-    // The new batch API provides space for n_vocab*n_tokens logits. Tell llama.cpp early
-    // that we want this many logits so the state serializes consistently.
-    d_ptr->ctx_params.logits_all = true;
-
-    d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
-    d_ptr->ctx_params.n_threads = d_ptr->n_threads;
-    d_ptr->ctx_params.n_threads_batch = d_ptr->n_threads;
-
 #ifdef GGML_USE_METAL
     if (llama_verbose()) {
         std::cerr << "llama.cpp: using Metal" << std::endl;
@@ -197,6 +189,28 @@ bool LLamaModel::loadModel(const std::string &modelPath)
         return false;
     }
 
+    const int n_ctx_train = llama_n_ctx_train(d_ptr->model);
+    if (n_ctx > n_ctx_train) {
+        std::cerr << "warning: model was trained on only " << n_ctx_train << " context tokens ("
+                  << n_ctx << " specified)\n";
+    }
+
+    // -- initialize the context --
+
+    d_ptr->ctx_params = llama_context_default_params();
+
+    d_ptr->ctx_params.n_ctx = n_ctx;
+    d_ptr->ctx_params.seed = params.seed;
+    d_ptr->ctx_params.f16_kv = params.memory_f16;
+
+    // The new batch API provides space for n_vocab*n_tokens logits. Tell llama.cpp early
+    // that we want this many logits so the state serializes consistently.
+    d_ptr->ctx_params.logits_all = true;
+
+    d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
+    d_ptr->ctx_params.n_threads = d_ptr->n_threads;
+    d_ptr->ctx_params.n_threads_batch = d_ptr->n_threads;
+
     d_ptr->ctx = llama_new_context_with_model(d_ptr->model, d_ptr->ctx_params);
     if (!d_ptr->ctx) {
 #ifdef GGML_USE_KOMPUTE
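As a rough illustration (not part of the commit) of the KV-cache estimate that requiredMem() computes, now driven by the caller-supplied n_ctx rather than a hardcoded 2048, here is a standalone C++ sketch; the hyperparameters are assumed 7B-class values, not numbers taken from this diff:

    // Worked example: the KV-cache term of the estimate scales linearly with n_ctx.
    #include <cstddef>
    #include <cstdio>

    int main() {
        const size_t n_embd  = 4096;            // assumption: typical 7B embedding width
        const size_t n_layer = 32;              // assumption: typical 7B layer count
        const size_t n_ctx   = 2048;            // the previously hardcoded default
        const size_t kvcache_element_size = 2;  // fp16, as in requiredMem()
        const size_t est = n_embd * n_layer * 2u * n_ctx * kvcache_element_size;
        std::printf("estimated KV cache: %zu bytes (~%zu MiB)\n", est, est >> 20);
        // prints 1073741824 bytes (~1024 MiB); doubling n_ctx doubles this estimate
        return 0;
    }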
@@ -17,9 +17,9 @@ public:
 
     bool supportsEmbedding() const override { return false; }
     bool supportsCompletion() const override { return true; }
-    bool loadModel(const std::string &modelPath) override;
+    bool loadModel(const std::string &modelPath, int n_ctx) override;
     bool isModelLoaded() const override;
-    size_t requiredMem(const std::string &modelPath) override;
+    size_t requiredMem(const std::string &modelPath, int n_ctx) override;
     size_t stateSize() const override;
     size_t saveState(uint8_t *dest) const override;
     size_t restoreState(const uint8_t *src) override;
@@ -138,7 +138,7 @@ const LLModel::Implementation* LLModel::Implementation::implementation(const cha
     return nullptr;
 }
 
-LLModel *LLModel::Implementation::construct(const std::string &modelPath, std::string buildVariant) {
+LLModel *LLModel::Implementation::construct(const std::string &modelPath, std::string buildVariant, int n_ctx) {
     if (!has_at_least_minimal_hardware()) {
         std::cerr << "LLModel ERROR: CPU does not support AVX\n";
         return nullptr;
@@ -154,7 +154,11 @@ LLModel *LLModel::Implementation::construct(const std::string &modelPath, std::s
         if(impl) {
             LLModel* metalimpl = impl->m_construct();
             metalimpl->m_implementation = impl;
-            size_t req_mem = metalimpl->requiredMem(modelPath);
+            /* TODO(cebtenzzre): after we fix requiredMem, we should change this to happen at
+             * load time, not construct time. right now n_ctx is incorrectly hardcoded 2048 in
+             * most (all?) places where this is called, causing underestimation of required
+             * memory. */
+            size_t req_mem = metalimpl->requiredMem(modelPath, n_ctx);
             float req_to_total = (float) req_mem / (float) total_mem;
             // on a 16GB M2 Mac a 13B q4_0 (0.52) works for me but a 13B q4_K_M (0.55) does not
             if (req_to_total >= 0.53) {
@@ -165,6 +169,8 @@ LLModel *LLModel::Implementation::construct(const std::string &modelPath, std::s
             }
         }
     }
+#else
+    (void)n_ctx;
 #endif
 
     if (!impl) {
@@ -37,7 +37,7 @@ public:
     static bool isImplementation(const Dlhandle&);
     static const std::vector<Implementation>& implementationList();
     static const Implementation *implementation(const char *fname, const std::string& buildVariant);
-    static LLModel *construct(const std::string &modelPath, std::string buildVariant = "auto");
+    static LLModel *construct(const std::string &modelPath, std::string buildVariant = "auto", int n_ctx = 2048);
     static std::vector<GPUDevice> availableGPUDevices();
     static void setImplementationsSearchPath(const std::string& path);
     static const std::string& implementationsSearchPath();
@@ -74,9 +74,9 @@ public:
 
     virtual bool supportsEmbedding() const = 0;
     virtual bool supportsCompletion() const = 0;
-    virtual bool loadModel(const std::string &modelPath) = 0;
+    virtual bool loadModel(const std::string &modelPath, int n_ctx) = 0;
     virtual bool isModelLoaded() const = 0;
-    virtual size_t requiredMem(const std::string &modelPath) = 0;
+    virtual size_t requiredMem(const std::string &modelPath, int n_ctx) = 0;
     virtual size_t stateSize() const { return 0; }
     virtual size_t saveState(uint8_t */*dest*/) const { return 0; }
     virtual size_t restoreState(const uint8_t */*src*/) { return 0; }
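A minimal sketch (not from the repository) of how a C++ consumer of the updated llmodel.h interface would pass the new parameter; the model path and the 4096 value are placeholders chosen for illustration:

    #include "llmodel.h"
    #include <string>

    int main() {
        const std::string path = "/path/to/model.bin";  // hypothetical model path
        const int n_ctx = 4096;                         // example context length
        LLModel *model = LLModel::Implementation::construct(path, "auto", n_ctx);
        if (!model || !model->loadModel(path, n_ctx))
            return 1;                                   // construct or load failed
        // ... prompt the model, then clean up ...
        delete model;
        return 0;
    }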
@@ -47,16 +47,16 @@ void llmodel_model_destroy(llmodel_model model) {
     delete reinterpret_cast<LLModelWrapper*>(model);
 }
 
-size_t llmodel_required_mem(llmodel_model model, const char *model_path)
+size_t llmodel_required_mem(llmodel_model model, const char *model_path, int n_ctx)
 {
     LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
-    return wrapper->llModel->requiredMem(model_path);
+    return wrapper->llModel->requiredMem(model_path, n_ctx);
 }
 
-bool llmodel_loadModel(llmodel_model model, const char *model_path)
+bool llmodel_loadModel(llmodel_model model, const char *model_path, int n_ctx)
 {
     LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
-    return wrapper->llModel->loadModel(model_path);
+    return wrapper->llModel->loadModel(model_path, n_ctx);
 }
 
 bool llmodel_isModelLoaded(llmodel_model model)
@@ -110,17 +110,19 @@ void llmodel_model_destroy(llmodel_model model);
  * Estimate RAM requirement for a model file
  * @param model A pointer to the llmodel_model instance.
  * @param model_path A string representing the path to the model file.
+ * @param n_ctx Maximum size of context window
  * @return size greater than 0 if the model was parsed successfully, 0 if file could not be parsed.
  */
-size_t llmodel_required_mem(llmodel_model model, const char *model_path);
+size_t llmodel_required_mem(llmodel_model model, const char *model_path, int n_ctx);
 
 /**
  * Load a model from a file.
  * @param model A pointer to the llmodel_model instance.
  * @param model_path A string representing the path to the model file.
+ * @param n_ctx Maximum size of context window
  * @return true if the model was loaded successfully, false otherwise.
  */
-bool llmodel_loadModel(llmodel_model model, const char *model_path);
+bool llmodel_loadModel(llmodel_model model, const char *model_path, int n_ctx);
 
 /**
  * Check if a model is loaded.
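For illustration only, a sketch of the updated C-API call sequence, written in C++ and assuming the llmodel_model_create2/llmodel_model_destroy signatures used elsewhere in this diff; the path handling and the 4096 context size are placeholders:

    #include "llmodel_c.h"
    #include <cstdio>

    int load_with_ctx(const char *model_path) {
        const char *err = nullptr;
        llmodel_model model = llmodel_model_create2(model_path, "auto", &err);
        if (model == nullptr) {
            std::fprintf(stderr, "create failed: %s\n", err ? err : "unknown");
            return 1;
        }
        size_t mem = llmodel_required_mem(model, model_path, 4096);  // new n_ctx argument
        std::fprintf(stderr, "estimated RAM: %zu bytes\n", mem);
        if (!llmodel_loadModel(model, model_path, 4096)) {           // new n_ctx argument
            llmodel_model_destroy(model);
            return 1;
        }
        llmodel_model_destroy(model);
        return 0;
    }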
@@ -188,7 +188,7 @@ public class LLModel : ILLModel
     /// <returns>true if the model was loaded successfully, false otherwise.</returns>
     public bool Load(string modelPath)
     {
-        return NativeMethods.llmodel_loadModel(_handle, modelPath);
+        return NativeMethods.llmodel_loadModel(_handle, modelPath, 2048);
     }
 
     protected void Destroy()
@@ -70,7 +70,8 @@ internal static unsafe partial class NativeMethods
     [return: MarshalAs(UnmanagedType.I1)]
     public static extern bool llmodel_loadModel(
         [NativeTypeName("llmodel_model")] IntPtr model,
-        [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path);
+        [NativeTypeName("const char *")][MarshalAs(UnmanagedType.LPUTF8Str)] string model_path,
+        [NativeTypeName("int32_t")] int n_ctx);
 
     [DllImport("libllmodel", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
 
@@ -39,7 +39,7 @@ public class Gpt4AllModelFactory : IGpt4AllModelFactory
         var handle = NativeMethods.llmodel_model_create2(modelPath, "auto", out error);
         _logger.LogDebug("Model created handle=0x{ModelHandle:X8}", handle);
         _logger.LogInformation("Model loading started");
-        var loadedSuccessfully = NativeMethods.llmodel_loadModel(handle, modelPath);
+        var loadedSuccessfully = NativeMethods.llmodel_loadModel(handle, modelPath, 2048);
         _logger.LogInformation("Model loading completed success={ModelLoadSuccess}", loadedSuccessfully);
         if (!loadedSuccessfully)
         {
@@ -23,7 +23,7 @@ void* load_model(const char *fname, int n_threads) {
         fprintf(stderr, "%s: error '%s'\n", __func__, new_error);
         return nullptr;
     }
-    if (!llmodel_loadModel(model, fname)) {
+    if (!llmodel_loadModel(model, fname, 2048)) {
         llmodel_model_destroy(model);
         return nullptr;
     }
@@ -195,7 +195,7 @@ public class LLModel implements AutoCloseable {
         if(model == null) {
             throw new IllegalStateException("Could not load, gpt4all backend returned error: " + error.getValue().getString(0));
         }
-        library.llmodel_loadModel(model, modelPathAbs);
+        library.llmodel_loadModel(model, modelPathAbs, 2048);
 
         if(!library.llmodel_isModelLoaded(model)){
             throw new IllegalStateException("The model " + modelName + " could not be loaded");
@@ -61,7 +61,7 @@ public interface LLModelLibrary {
 
     Pointer llmodel_model_create2(String model_path, String build_variant, PointerByReference error);
     void llmodel_model_destroy(Pointer model);
-    boolean llmodel_loadModel(Pointer model, String model_path);
+    boolean llmodel_loadModel(Pointer model, String model_path, int n_ctx);
     boolean llmodel_isModelLoaded(Pointer model);
     @u_int64_t long llmodel_get_state_size(Pointer model);
     @u_int64_t long llmodel_save_state_data(Pointer model, Pointer dest);
@@ -1,2 +1,2 @@
-from .gpt4all import Embed4All, GPT4All  # noqa
-from .pyllmodel import LLModel  # noqa
+from .gpt4all import Embed4All as Embed4All, GPT4All as GPT4All
+from .pyllmodel import LLModel as LLModel
@@ -69,6 +69,7 @@ class GPT4All:
         allow_download: bool = True,
         n_threads: Optional[int] = None,
         device: Optional[str] = "cpu",
+        n_ctx: int = 2048,
         verbose: bool = False,
     ):
         """
@@ -90,15 +91,16 @@ class GPT4All:
                 Default is "cpu".
 
                 Note: If a selected GPU device does not have sufficient RAM to accommodate the model, an error will be thrown, and the GPT4All instance will be rendered invalid. It's advised to ensure the device has enough memory before initiating the model.
+            n_ctx: Maximum size of context window
+            verbose: If True, print debug messages.
         """
         self.model_type = model_type
         self.model = pyllmodel.LLModel()
         # Retrieve model and download if allowed
         self.config: ConfigType = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download, verbose=verbose)
-        if device is not None:
-            if device != "cpu":
-                self.model.init_gpu(model_path=self.config["path"], device=device)
-        self.model.load_model(self.config["path"])
+        if device is not None and device != "cpu":
+            self.model.init_gpu(model_path=self.config["path"], device=device, n_ctx=n_ctx)
+        self.model.load_model(self.config["path"], n_ctx)
         # Set n_threads
         if n_threads is not None:
             self.model.set_thread_count(n_threads)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import ctypes
 import importlib.resources
 import logging
@@ -7,6 +9,7 @@ import re
 import subprocess
 import sys
 import threading
+from enum import Enum
 from queue import Queue
 from typing import Callable, Iterable, List
 
@@ -72,9 +75,9 @@ llmodel.llmodel_model_create2.restype = ctypes.c_void_p
 llmodel.llmodel_model_destroy.argtypes = [ctypes.c_void_p]
 llmodel.llmodel_model_destroy.restype = None
 
-llmodel.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
+llmodel.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_int]
 llmodel.llmodel_loadModel.restype = ctypes.c_bool
-llmodel.llmodel_required_mem.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
+llmodel.llmodel_required_mem.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_int]
 llmodel.llmodel_required_mem.restype = ctypes.c_size_t
 llmodel.llmodel_isModelLoaded.argtypes = [ctypes.c_void_p]
 llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool
@@ -114,7 +117,7 @@ llmodel.llmodel_set_implementation_search_path.restype = None
 llmodel.llmodel_threadCount.argtypes = [ctypes.c_void_p]
 llmodel.llmodel_threadCount.restype = ctypes.c_int32
 
-llmodel.llmodel_set_implementation_search_path(str(MODEL_LIB_PATH).replace("\\", r"\\").encode("utf-8"))
+llmodel.llmodel_set_implementation_search_path(str(MODEL_LIB_PATH).replace("\\", r"\\").encode())
 
 llmodel.llmodel_available_gpu_devices.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.POINTER(ctypes.c_int32)]
 llmodel.llmodel_available_gpu_devices.restype = ctypes.POINTER(LLModelGPUDevice)
@@ -143,10 +146,16 @@ def _create_model(model_path: bytes) -> ctypes.c_void_p:
     err = ctypes.c_char_p()
     model = llmodel.llmodel_model_create2(model_path, b"auto", ctypes.byref(err))
     if model is None:
-        raise ValueError(f"Unable to instantiate model: {err.decode()}")
+        s = err.value
+        raise ValueError(f"Unable to instantiate model: {'null' if s is None else s.decode()}")
     return model
 
 
+# Symbol to terminate from generator
+class Sentinel(Enum):
+    TERMINATING_SYMBOL = 0
+
+
 class LLModel:
     """
     Base class and universal wrapper for GPT4All language models
@@ -173,12 +182,16 @@ class LLModel:
         if self.model is not None:
             self.llmodel_lib.llmodel_model_destroy(self.model)
 
-    def memory_needed(self, model_path: str) -> int:
-        model_path_enc = model_path.encode("utf-8")
-        self.model = _create_model(model_path_enc)
-        return llmodel.llmodel_required_mem(self.model, model_path_enc)
+    def memory_needed(self, model_path: str, n_ctx: int) -> int:
+        self.model = None
+        return self._memory_needed(model_path, n_ctx)
 
-    def list_gpu(self, model_path: str) -> list:
+    def _memory_needed(self, model_path: str, n_ctx: int) -> int:
+        if self.model is None:
+            self.model = _create_model(model_path.encode())
+        return llmodel.llmodel_required_mem(self.model, model_path.encode(), n_ctx)
+
+    def list_gpu(self, model_path: str, n_ctx: int) -> list[LLModelGPUDevice]:
         """
         Lists available GPU devices that satisfy the model's memory requirements.
 
@@ -186,45 +199,41 @@ class LLModel:
         ----------
         model_path : str
             Path to the model.
+        n_ctx : int
+            Maximum size of context window
 
         Returns
        -------
         list
             A list of LLModelGPUDevice structures representing available GPU devices.
         """
-        if self.model is not None:
-            model_path_enc = model_path.encode("utf-8")
-            mem_required = llmodel.llmodel_required_mem(self.model, model_path_enc)
-        else:
-            mem_required = self.memory_needed(model_path)
+        mem_required = self._memory_needed(model_path, n_ctx)
+        return self._list_gpu(mem_required)
+
+    def _list_gpu(self, mem_required: int) -> list[LLModelGPUDevice]:
         num_devices = ctypes.c_int32(0)
         devices_ptr = self.llmodel_lib.llmodel_available_gpu_devices(self.model, mem_required, ctypes.byref(num_devices))
         if not devices_ptr:
             raise ValueError("Unable to retrieve available GPU devices")
-        devices = [devices_ptr[i] for i in range(num_devices.value)]
-        return devices
+        return devices_ptr[:num_devices.value]
 
-    def init_gpu(self, model_path: str, device: str):
-        if self.model is not None:
-            model_path_enc = model_path.encode("utf-8")
-            mem_required = llmodel.llmodel_required_mem(self.model, model_path_enc)
-        else:
-            mem_required = self.memory_needed(model_path)
-        device_enc = device.encode("utf-8")
-        success = self.llmodel_lib.llmodel_gpu_init_gpu_device_by_string(self.model, mem_required, device_enc)
+    def init_gpu(self, model_path: str, device: str, n_ctx: int):
+        mem_required = self._memory_needed(model_path, n_ctx)
+
+        success = self.llmodel_lib.llmodel_gpu_init_gpu_device_by_string(self.model, mem_required, device.encode())
+
         if not success:
             # Retrieve all GPUs without considering memory requirements.
             num_devices = ctypes.c_int32(0)
             all_devices_ptr = self.llmodel_lib.llmodel_available_gpu_devices(self.model, 0, ctypes.byref(num_devices))
             if not all_devices_ptr:
                 raise ValueError("Unable to retrieve list of all GPU devices")
-            all_gpus = [all_devices_ptr[i].name.decode('utf-8') for i in range(num_devices.value)]
+            all_gpus = [d.name.decode() for d in all_devices_ptr[:num_devices.value]]
 
             # Retrieve GPUs that meet the memory requirements using list_gpu
-            available_gpus = [device.name.decode('utf-8') for device in self.list_gpu(model_path)]
+            available_gpus = [device.name.decode() for device in self._list_gpu(mem_required)]
 
             # Identify GPUs that are unavailable due to insufficient memory or features
-            unavailable_gpus = set(all_gpus) - set(available_gpus)
+            unavailable_gpus = set(all_gpus).difference(available_gpus)
 
             # Formulate the error message
             error_msg = "Unable to initialize model on GPU: '{}'.".format(device)
@@ -232,7 +241,7 @@ class LLModel:
             error_msg += "\nUnavailable GPUs due to insufficient memory or features: {}.".format(unavailable_gpus)
             raise ValueError(error_msg)
 
-    def load_model(self, model_path: str) -> bool:
+    def load_model(self, model_path: str, n_ctx: int) -> bool:
         """
         Load model from a file.
 
@@ -240,15 +249,16 @@ class LLModel:
         ----------
         model_path : str
             Model filepath
+        n_ctx : int
+            Maximum size of context window
 
         Returns
         -------
             True if model loaded successfully, False otherwise
         """
-        model_path_enc = model_path.encode("utf-8")
-        self.model = _create_model(model_path_enc)
+        self.model = _create_model(model_path.encode())
 
-        llmodel.llmodel_loadModel(self.model, model_path_enc)
+        llmodel.llmodel_loadModel(self.model, model_path.encode(), n_ctx)
 
         filename = os.path.basename(model_path)
         self.model_name = os.path.splitext(filename)[0]
@@ -312,7 +322,7 @@ class LLModel:
             raise ValueError("Text must not be None or empty")
 
         embedding_size = ctypes.c_size_t()
-        c_text = ctypes.c_char_p(text.encode('utf-8'))
+        c_text = ctypes.c_char_p(text.encode())
         embedding_ptr = llmodel.llmodel_embedding(self.model, c_text, ctypes.byref(embedding_size))
         embedding_array = [embedding_ptr[i] for i in range(embedding_size.value)]
         llmodel.llmodel_free_embedding(embedding_ptr)
@@ -357,7 +367,7 @@ class LLModel:
                 prompt,
             )
 
-        prompt_bytes = prompt.encode("utf-8")
+        prompt_bytes = prompt.encode()
         prompt_ptr = ctypes.c_char_p(prompt_bytes)
 
         self._set_context(
@@ -385,10 +395,7 @@ class LLModel:
     def prompt_model_streaming(
         self, prompt: str, callback: ResponseCallbackType = empty_response_callback, **kwargs
     ) -> Iterable[str]:
-        # Symbol to terminate from generator
-        TERMINATING_SYMBOL = object()
-
-        output_queue: Queue = Queue()
+        output_queue: Queue[str | Sentinel] = Queue()
 
         # Put response tokens into an output queue
         def _generator_callback_wrapper(callback: ResponseCallbackType) -> ResponseCallbackType:
@@ -405,7 +412,7 @@ class LLModel:
 
         def run_llmodel_prompt(prompt: str, callback: ResponseCallbackType, **kwargs):
             self.prompt_model(prompt, callback, **kwargs)
-            output_queue.put(TERMINATING_SYMBOL)
+            output_queue.put(Sentinel.TERMINATING_SYMBOL)
 
         # Kick off llmodel_prompt in separate thread so we can return generator
         # immediately
@@ -419,7 +426,7 @@ class LLModel:
         # Generator
         while True:
             response = output_queue.get()
-            if response is TERMINATING_SYMBOL:
+            if isinstance(response, Sentinel):
                 break
             yield response
 
@@ -442,7 +449,7 @@ class LLModel:
         else:
             # beginning of a byte sequence
             if len(self.buffer) > 0:
-                decoded.append(self.buffer.decode('utf-8', 'replace'))
+                decoded.append(self.buffer.decode(errors='replace'))
 
                 self.buffer.clear()
 
@@ -451,7 +458,7 @@ class LLModel:
 
             if self.buff_expecting_cont_bytes <= 0:
                 # received the whole sequence or an out of place continuation byte
-                decoded.append(self.buffer.decode('utf-8', 'replace'))
+                decoded.append(self.buffer.decode(errors='replace'))
 
                 self.buffer.clear()
                 self.buff_expecting_cont_bytes = 0
@@ -117,7 +117,7 @@ def test_empty_embedding():
 def test_download_model(tmp_path: Path):
     import gpt4all.gpt4all
     old_default_dir = gpt4all.gpt4all.DEFAULT_MODEL_DIRECTORY
-    gpt4all.gpt4all.DEFAULT_MODEL_DIRECTORY = tmp_path  # temporary pytest directory to ensure a download happens
+    gpt4all.gpt4all.DEFAULT_MODEL_DIRECTORY = str(tmp_path)  # temporary pytest directory to ensure a download happens
     try:
         model = GPT4All(model_name='ggml-all-MiniLM-L6-v2-f16.bin')
         model_path = tmp_path / model.config['filename']
@@ -28,7 +28,7 @@ Napi::Function NodeModelWrapper::GetClass(Napi::Env env) {
 Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
 {
     auto env = info.Env();
-    return Napi::Number::New(env, static_cast<uint32_t>( llmodel_required_mem(GetInference(), full_model_path.c_str()) ));
+    return Napi::Number::New(env, static_cast<uint32_t>( llmodel_required_mem(GetInference(), full_model_path.c_str(), 2048) ));
 
 }
 Napi::Value NodeModelWrapper::GetGpuDevices(const Napi::CallbackInfo& info)
@@ -161,7 +161,7 @@ Napi::Value NodeModelWrapper::GetRequiredMemory(const Napi::CallbackInfo& info)
         }
     }
 
-    auto success = llmodel_loadModel(GetInference(), full_weight_path.c_str());
+    auto success = llmodel_loadModel(GetInference(), full_weight_path.c_str(), 2048);
     if(!success) {
         Napi::Error::New(env, "Failed to load model at given path").ThrowAsJavaScriptException();
         return;
@@ -20,15 +20,17 @@ ChatGPT::ChatGPT()
 {
 }
 
-size_t ChatGPT::requiredMem(const std::string &modelPath)
+size_t ChatGPT::requiredMem(const std::string &modelPath, int n_ctx)
 {
     Q_UNUSED(modelPath);
+    Q_UNUSED(n_ctx);
     return 0;
 }
 
-bool ChatGPT::loadModel(const std::string &modelPath)
+bool ChatGPT::loadModel(const std::string &modelPath, int n_ctx)
 {
     Q_UNUSED(modelPath);
+    Q_UNUSED(n_ctx);
     return true;
 }
 
@@ -48,9 +48,9 @@ public:
 
     bool supportsEmbedding() const override { return false; }
     bool supportsCompletion() const override { return true; }
-    bool loadModel(const std::string &modelPath) override;
+    bool loadModel(const std::string &modelPath, int n_ctx) override;
     bool isModelLoaded() const override;
-    size_t requiredMem(const std::string &modelPath) override;
+    size_t requiredMem(const std::string &modelPath, int n_ctx) override;
     size_t stateSize() const override;
     size_t saveState(uint8_t *dest) const override;
     size_t restoreState(const uint8_t *src) override;
@@ -5,7 +5,7 @@
 #include <QDataStream>
 
 #define CHAT_FORMAT_MAGIC 0xF5D553CC
-#define CHAT_FORMAT_VERSION 6
+#define CHAT_FORMAT_VERSION 7
 
 class MyChatListModel: public ChatListModel { };
 Q_GLOBAL_STATIC(MyChatListModel, chatListModelInstance)
@@ -248,14 +248,16 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
             m_llModelInfo.model = model;
         } else {
 
+            // TODO: make configurable in UI
+            auto n_ctx = MySettings::globalInstance()->modelContextLength(modelInfo);
+            m_ctx.n_ctx = n_ctx;
+
+            std::string buildVariant = "auto";
 #if defined(Q_OS_MAC) && defined(__arm__)
             if (m_forceMetal)
-                m_llModelInfo.model = LLMImplementation::construct(filePath.toStdString(), "metal");
-            else
-                m_llModelInfo.model = LLMImplementation::construct(filePath.toStdString(), "auto");
-#else
-            m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), "auto");
+                buildVariant = "metal";
 #endif
+            m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), buildVariant, n_ctx);
+
             if (m_llModelInfo.model) {
                 // Update the settings that a model is being loaded and update the device list
@@ -267,7 +269,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
                 if (requestedDevice == "CPU") {
                     emit reportFallbackReason(""); // fallback not applicable
                 } else {
-                    const size_t requiredMemory = m_llModelInfo.model->requiredMem(filePath.toStdString());
+                    const size_t requiredMemory = m_llModelInfo.model->requiredMem(filePath.toStdString(), n_ctx);
                     std::vector<LLModel::GPUDevice> availableDevices = m_llModelInfo.model->availableGPUDevices(requiredMemory);
                     LLModel::GPUDevice *device = nullptr;
 
@@ -296,14 +298,14 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
                 // Report which device we're actually using
                 emit reportDevice(actualDevice);
 
-                bool success = m_llModelInfo.model->loadModel(filePath.toStdString());
+                bool success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx);
                 if (actualDevice == "CPU") {
                     // we asked llama.cpp to use the CPU
                 } else if (!success) {
                     // llama_init_from_file returned nullptr
                     emit reportDevice("CPU");
                     emit reportFallbackReason("<br>GPU loading failed (out of VRAM?)");
-                    success = m_llModelInfo.model->loadModel(filePath.toStdString());
+                    success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx);
                 } else if (!m_llModelInfo.model->usingGPUDevice()) {
                     // ggml_vk_init was not called in llama.cpp
                     // We might have had to fallback to CPU after load if the model is not possible to accelerate
@@ -763,6 +765,8 @@ bool ChatLLM::handleRestoreStateFromTextRecalculate(bool isRecalc)
     return false;
 }
 
+// this function serialized the cached model state to disk.
+// we want to also serialize n_ctx, and read it at load time.
 bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
 {
     if (version > 1) {
@@ -790,6 +794,9 @@ bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
         stream << responseLogits;
     }
     stream << m_ctx.n_past;
+    if (version >= 6) {
+        stream << m_ctx.n_ctx;
+    }
     stream << quint64(m_ctx.logits.size());
     stream.writeRawData(reinterpret_cast<const char*>(m_ctx.logits.data()), m_ctx.logits.size() * sizeof(float));
     stream << quint64(m_ctx.tokens.size());
@@ -839,6 +846,12 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
     stream >> n_past;
     if (!discardKV) m_ctx.n_past = n_past;
 
+    if (version >= 6) {
+        uint32_t n_ctx;
+        stream >> n_ctx;
+        if (!discardKV) m_ctx.n_ctx = n_ctx;
+    }
+
     quint64 logitsSize;
     stream >> logitsSize;
     if (!discardKV) {
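The chat-state (de)serialization above gates the new field on the stream version so chats saved before the format bump still load. A stripped-down sketch of that pattern (function and parameter names here are illustrative, not the project's exact members):

    #include <QDataStream>

    // Write: only streams n_ctx for formats new enough to contain it.
    void writeCtx(QDataStream &stream, int version, qint32 n_past, quint32 n_ctx) {
        stream << n_past;
        if (version >= 6)
            stream << n_ctx;   // new field introduced alongside CHAT_FORMAT_VERSION 7
    }

    // Read: older streams simply leave the caller's n_ctx untouched.
    void readCtx(QDataStream &stream, int version, qint32 &n_past, quint32 &n_ctx) {
        stream >> n_past;
        if (version >= 6)
            stream >> n_ctx;
    }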
@@ -29,8 +29,8 @@ bool EmbeddingLLM::loadModel()
         return false;
     }
 
-    m_model = LLModel::Implementation::construct(filePath.toStdString(), "auto");
-    bool success = m_model->loadModel(filePath.toStdString());
+    m_model = LLModel::Implementation::construct(filePath.toStdString());
+    bool success = m_model->loadModel(filePath.toStdString(), 2048);
     if (!success) {
         qWarning() << "WARNING: Could not load sbert";
         delete m_model;
@@ -97,6 +97,17 @@ void ModelInfo::setPromptBatchSize(int s)
     m_promptBatchSize = s;
 }
 
+int ModelInfo::contextLength() const
+{
+    return MySettings::globalInstance()->modelContextLength(*this);
+}
+
+void ModelInfo::setContextLength(int l)
+{
+    if (isClone) MySettings::globalInstance()->setModelContextLength(*this, l, isClone /*force*/);
+    m_contextLength = l;
+}
+
 double ModelInfo::repeatPenalty() const
 {
     return MySettings::globalInstance()->modelRepeatPenalty(*this);
@@ -274,6 +285,7 @@ ModelList::ModelList()
     connect(MySettings::globalInstance(), &MySettings::topKChanged, this, &ModelList::updateDataForSettings);
     connect(MySettings::globalInstance(), &MySettings::maxLengthChanged, this, &ModelList::updateDataForSettings);
     connect(MySettings::globalInstance(), &MySettings::promptBatchSizeChanged, this, &ModelList::updateDataForSettings);
+    connect(MySettings::globalInstance(), &MySettings::contextLengthChanged, this, &ModelList::updateDataForSettings);
     connect(MySettings::globalInstance(), &MySettings::repeatPenaltyChanged, this, &ModelList::updateDataForSettings);
     connect(MySettings::globalInstance(), &MySettings::repeatPenaltyTokensChanged, this, &ModelList::updateDataForSettings);;
     connect(MySettings::globalInstance(), &MySettings::promptTemplateChanged, this, &ModelList::updateDataForSettings);
@@ -525,6 +537,8 @@ QVariant ModelList::dataInternal(const ModelInfo *info, int role) const
         return info->maxLength();
     case PromptBatchSizeRole:
         return info->promptBatchSize();
+    case ContextLengthRole:
+        return info->contextLength();
     case RepeatPenaltyRole:
         return info->repeatPenalty();
     case RepeatPenaltyTokensRole:
@@ -740,6 +754,7 @@ QString ModelList::clone(const ModelInfo &model)
     updateData(id, ModelList::TopKRole, model.topK());
     updateData(id, ModelList::MaxLengthRole, model.maxLength());
     updateData(id, ModelList::PromptBatchSizeRole, model.promptBatchSize());
+    updateData(id, ModelList::ContextLengthRole, model.contextLength());
     updateData(id, ModelList::RepeatPenaltyRole, model.repeatPenalty());
     updateData(id, ModelList::RepeatPenaltyTokensRole, model.repeatPenaltyTokens());
     updateData(id, ModelList::PromptTemplateRole, model.promptTemplate());
@@ -1106,6 +1121,8 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save)
             updateData(id, ModelList::MaxLengthRole, obj["maxLength"].toInt());
         if (obj.contains("promptBatchSize"))
             updateData(id, ModelList::PromptBatchSizeRole, obj["promptBatchSize"].toInt());
+        if (obj.contains("contextLength"))
+            updateData(id, ModelList::ContextLengthRole, obj["contextLength"].toInt());
         if (obj.contains("repeatPenalty"))
             updateData(id, ModelList::RepeatPenaltyRole, obj["repeatPenalty"].toDouble());
         if (obj.contains("repeatPenaltyTokens"))
@@ -1198,6 +1215,8 @@ void ModelList::updateModelsFromSettings()
             const int maxLength = settings.value(g + "/maxLength").toInt();
             Q_ASSERT(settings.contains(g + "/promptBatchSize"));
             const int promptBatchSize = settings.value(g + "/promptBatchSize").toInt();
+            Q_ASSERT(settings.contains(g + "/contextLength"));
+            const int contextLength = settings.value(g + "/contextLength").toInt();
             Q_ASSERT(settings.contains(g + "/repeatPenalty"));
             const double repeatPenalty = settings.value(g + "/repeatPenalty").toDouble();
             Q_ASSERT(settings.contains(g + "/repeatPenaltyTokens"));
@@ -1216,6 +1235,7 @@ void ModelList::updateModelsFromSettings()
             updateData(id, ModelList::TopKRole, topK);
             updateData(id, ModelList::MaxLengthRole, maxLength);
             updateData(id, ModelList::PromptBatchSizeRole, promptBatchSize);
+            updateData(id, ModelList::ContextLengthRole, contextLength);
             updateData(id, ModelList::RepeatPenaltyRole, repeatPenalty);
             updateData(id, ModelList::RepeatPenaltyTokensRole, repeatPenaltyTokens);
             updateData(id, ModelList::PromptTemplateRole, promptTemplate);
@ -39,6 +39,7 @@ struct ModelInfo {
|
|||||||
Q_PROPERTY(int topK READ topK WRITE setTopK)
|
Q_PROPERTY(int topK READ topK WRITE setTopK)
|
||||||
Q_PROPERTY(int maxLength READ maxLength WRITE setMaxLength)
|
Q_PROPERTY(int maxLength READ maxLength WRITE setMaxLength)
|
||||||
Q_PROPERTY(int promptBatchSize READ promptBatchSize WRITE setPromptBatchSize)
|
Q_PROPERTY(int promptBatchSize READ promptBatchSize WRITE setPromptBatchSize)
|
||||||
|
Q_PROPERTY(int contextLength READ contextLength WRITE setContextLength)
|
||||||
Q_PROPERTY(double repeatPenalty READ repeatPenalty WRITE setRepeatPenalty)
|
Q_PROPERTY(double repeatPenalty READ repeatPenalty WRITE setRepeatPenalty)
|
||||||
Q_PROPERTY(int repeatPenaltyTokens READ repeatPenaltyTokens WRITE setRepeatPenaltyTokens)
|
Q_PROPERTY(int repeatPenaltyTokens READ repeatPenaltyTokens WRITE setRepeatPenaltyTokens)
|
||||||
Q_PROPERTY(QString promptTemplate READ promptTemplate WRITE setPromptTemplate)
|
Q_PROPERTY(QString promptTemplate READ promptTemplate WRITE setPromptTemplate)
|
||||||
@@ -94,6 +95,8 @@ public:
     void setMaxLength(int l);
     int promptBatchSize() const;
    void setPromptBatchSize(int s);
+    int contextLength() const;
+    void setContextLength(int l);
     double repeatPenalty() const;
     void setRepeatPenalty(double p);
     int repeatPenaltyTokens() const;
@@ -112,6 +115,7 @@ private:
     int m_topK = 40;
     int m_maxLength = 4096;
     int m_promptBatchSize = 128;
+    int m_contextLength = 2048;
     double m_repeatPenalty = 1.18;
     int m_repeatPenaltyTokens = 64;
     QString m_promptTemplate = "### Human:\n%1\n### Assistant:\n";
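ModelInfo gains a contextLength property, accessor declarations, and a 2048-token default member. The accessor definitions are not part of this excerpt; the following is only a sketch, assuming they delegate to MySettings like the other per-model fields and assuming a globalInstance() accessor; the real setter may also force-persist or skip the member write.

// Sketch only; assumed shape of the accessors declared above, not the actual definitions.
int ModelInfo::contextLength() const
{
    return MySettings::globalInstance()->modelContextLength(*this); // user override or 2048 default
}

void ModelInfo::setContextLength(int l)
{
    MySettings::globalInstance()->setModelContextLength(*this, l);
    m_contextLength = l;
}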
@@ -227,6 +231,7 @@ public:
         TopKRole,
         MaxLengthRole,
         PromptBatchSizeRole,
+        ContextLengthRole,
         RepeatPenaltyRole,
         RepeatPenaltyTokensRole,
         PromptTemplateRole,
@@ -269,6 +274,7 @@ public:
         roles[TopKRole] = "topK";
         roles[MaxLengthRole] = "maxLength";
         roles[PromptBatchSizeRole] = "promptBatchSize";
+        roles[ContextLengthRole] = "contextLength";
         roles[RepeatPenaltyRole] = "repeatPenalty";
         roles[RepeatPenaltyTokensRole] = "repeatPenaltyTokens";
         roles[PromptTemplateRole] = "promptTemplate";
@@ -90,6 +90,7 @@ void MySettings::restoreModelDefaults(const ModelInfo &model)
     setModelTopK(model, model.m_topK);;
     setModelMaxLength(model, model.m_maxLength);
     setModelPromptBatchSize(model, model.m_promptBatchSize);
+    setModelContextLength(model, model.m_contextLength);
     setModelRepeatPenalty(model, model.m_repeatPenalty);
     setModelRepeatPenaltyTokens(model, model.m_repeatPenaltyTokens);
     setModelPromptTemplate(model, model.m_promptTemplate);
@@ -280,6 +281,28 @@ void MySettings::setModelPromptBatchSize(const ModelInfo &m, int s, bool force)
         emit promptBatchSizeChanged(m);
 }

+int MySettings::modelContextLength(const ModelInfo &m) const
+{
+    QSettings setting;
+    setting.sync();
+    return setting.value(QString("model-%1").arg(m.id()) + "/contextLength", m.m_contextLength).toInt();
+}
+
+void MySettings::setModelContextLength(const ModelInfo &m, int l, bool force)
+{
+    if (modelContextLength(m) == l && !force)
+        return;
+
+    QSettings setting;
+    if (m.m_contextLength == l && !m.isClone)
+        setting.remove(QString("model-%1").arg(m.id()) + "/contextLength");
+    else
+        setting.setValue(QString("model-%1").arg(m.id()) + "/contextLength", l);
+    setting.sync();
+    if (!force)
+        emit contextLengthChanged(m);
+}
+
 double MySettings::modelRepeatPenalty(const ModelInfo &m) const
 {
     QSettings setting;
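Note the write-side idiom in setModelContextLength above: a value equal to the built-in default (for non-clones) removes the key instead of writing it, so QSettings only ever stores overrides. A standalone sketch of that idiom; the function name and the 2048 default are illustrative.

#include <QSettings>
#include <QString>

// Sketch: store only overrides; remove the key when the value matches the assumed default.
void persistContextLength(const QString &modelId, int value, int builtinDefault = 2048)
{
    QSettings setting;
    const QString key = QString("model-%1").arg(modelId) + "/contextLength";
    if (value == builtinDefault)
        setting.remove(key);
    else
        setting.setValue(key, value);
    setting.sync();
}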
@@ -1,6 +1,8 @@
 #ifndef MYSETTINGS_H
 #define MYSETTINGS_H

+#include <cstdint>
+
 #include <QObject>
 #include <QMutex>

@@ -59,6 +61,8 @@ public:
     Q_INVOKABLE void setModelPromptTemplate(const ModelInfo &m, const QString &t, bool force = false);
     QString modelSystemPrompt(const ModelInfo &m) const;
     Q_INVOKABLE void setModelSystemPrompt(const ModelInfo &m, const QString &p, bool force = false);
+    int modelContextLength(const ModelInfo &m) const;
+    Q_INVOKABLE void setModelContextLength(const ModelInfo &m, int s, bool force = false);

     // Application settings
     int threadCount() const;
@@ -79,6 +83,8 @@ public:
     void setForceMetal(bool b);
     QString device() const;
     void setDevice(const QString &u);
+    int32_t contextLength() const;
+    void setContextLength(int32_t value);

     // Release/Download settings
     QString lastVersionStarted() const;
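Besides the per-model setting, MySettings also declares an application-level contextLength()/setContextLength(int32_t) pair, which is why <cstdint> is now included. Their definitions are not shown in this excerpt; a minimal sketch of a plausible QSettings-backed shape, where the key name and the 2048 default are assumptions rather than values taken from the source.

#include <QSettings>
#include <cstdint>

// Sketch with an assumed key and default; standalone stand-ins, not the real MySettings members.
int32_t readGlobalContextLength()
{
    QSettings setting;
    setting.sync();
    return static_cast<int32_t>(setting.value("contextLength", 2048).toInt());
}

void writeGlobalContextLength(int32_t value)
{
    QSettings setting;
    setting.setValue("contextLength", static_cast<int>(value));
    setting.sync();
}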
@@ -114,6 +120,7 @@ Q_SIGNALS:
     void topKChanged(const ModelInfo &model);
     void maxLengthChanged(const ModelInfo &model);
     void promptBatchSizeChanged(const ModelInfo &model);
+    void contextLengthChanged(const ModelInfo &model);
     void repeatPenaltyChanged(const ModelInfo &model);
     void repeatPenaltyTokensChanged(const ModelInfo &model);
     void promptTemplateChanged(const ModelInfo &model);
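The new contextLengthChanged(const ModelInfo &) signal is what the QML Connections elements further below listen to. A hedged sketch of subscribing to it from C++; the header names and the way a MySettings pointer is obtained here are assumptions.

#include <QDebug>
#include <QObject>
#include "modellist.h"   // assumed header locations for ModelInfo and MySettings
#include "mysettings.h"

void watchContextLength(MySettings *settings)
{
    QObject::connect(settings, &MySettings::contextLengthChanged, settings,
                     [](const ModelInfo &m) {
        // m.id() is the same identifier MySettings uses in its QSettings keys.
        qDebug() << "context length changed for model" << m.id();
    });
}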
@@ -349,13 +349,61 @@ MySettingsTab {
         rowSpacing: 10
         columnSpacing: 10

+        Label {
+            id: contextLengthLabel
+            visible: !root.currentModelInfo.isChatGPT
+            text: qsTr("Context Length:")
+            font.pixelSize: theme.fontSizeLarge
+            color: theme.textColor
+            Layout.row: 0
+            Layout.column: 0
+        }
+        MyTextField {
+            id: contextLengthField
+            visible: !root.currentModelInfo.isChatGPT
+            text: root.currentModelInfo.contextLength
+            color: theme.textColor
+            font.pixelSize: theme.fontSizeLarge
+            ToolTip.text: qsTr("Maximum combined prompt/response tokens before information is lost.\nUsing more context than the model was trained on will yield poor results.\nNOTE: Does not take effect until you RESTART GPT4All or SWITCH MODELS.")
+            ToolTip.visible: hovered
+            Layout.row: 0
+            Layout.column: 1
+            validator: IntValidator {
+                bottom: 1
+            }
+            Connections {
+                target: MySettings
+                function onContextLengthChanged() {
+                    contextLengthField.text = root.currentModelInfo.contextLength;
+                }
+            }
+            Connections {
+                target: root
+                function onCurrentModelInfoChanged() {
+                    contextLengthField.text = root.currentModelInfo.contextLength;
+                }
+            }
+            onEditingFinished: {
+                var val = parseInt(text)
+                if (!isNaN(val)) {
+                    MySettings.setModelContextLength(root.currentModelInfo, val)
+                    focus = false
+                } else {
+                    text = root.currentModelInfo.contextLength
+                }
+            }
+            Accessible.role: Accessible.EditableText
+            Accessible.name: contextLengthLabel.text
+            Accessible.description: ToolTip.text
+        }
+
         Label {
             id: tempLabel
             text: qsTr("Temperature:")
             color: theme.textColor
             font.pixelSize: theme.fontSizeLarge
-            Layout.row: 0
-            Layout.column: 0
+            Layout.row: 1
+            Layout.column: 2
         }

         MyTextField {
@@ -365,8 +413,8 @@ MySettingsTab {
             font.pixelSize: theme.fontSizeLarge
             ToolTip.text: qsTr("Temperature increases the chances of choosing less likely tokens.\nNOTE: Higher temperature gives more creative but less predictable outputs.")
             ToolTip.visible: hovered
-            Layout.row: 0
-            Layout.column: 1
+            Layout.row: 1
+            Layout.column: 3
             validator: DoubleValidator {
                 locale: "C"
             }
@@ -400,8 +448,8 @@ MySettingsTab {
             text: qsTr("Top P:")
             color: theme.textColor
             font.pixelSize: theme.fontSizeLarge
-            Layout.row: 0
-            Layout.column: 2
+            Layout.row: 2
+            Layout.column: 0
         }
         MyTextField {
             id: topPField
@@ -410,8 +458,8 @@ MySettingsTab {
             font.pixelSize: theme.fontSizeLarge
             ToolTip.text: qsTr("Only the most likely tokens up to a total probability of top_p can be chosen.\nNOTE: Prevents choosing highly unlikely tokens, aka Nucleus Sampling")
             ToolTip.visible: hovered
-            Layout.row: 0
-            Layout.column: 3
+            Layout.row: 2
+            Layout.column: 1
             validator: DoubleValidator {
                 locale: "C"
             }
@@ -446,8 +494,8 @@ MySettingsTab {
             text: qsTr("Top K:")
             color: theme.textColor
             font.pixelSize: theme.fontSizeLarge
-            Layout.row: 1
-            Layout.column: 0
+            Layout.row: 2
+            Layout.column: 2
         }
         MyTextField {
             id: topKField
@@ -457,8 +505,8 @@ MySettingsTab {
             font.pixelSize: theme.fontSizeLarge
             ToolTip.text: qsTr("Only the top K most likely tokens will be chosen from")
             ToolTip.visible: hovered
-            Layout.row: 1
-            Layout.column: 1
+            Layout.row: 2
+            Layout.column: 3
             validator: IntValidator {
                 bottom: 1
             }
@@ -493,7 +541,7 @@ MySettingsTab {
             text: qsTr("Max Length:")
             color: theme.textColor
             font.pixelSize: theme.fontSizeLarge
-            Layout.row: 1
+            Layout.row: 0
             Layout.column: 2
         }
         MyTextField {
@@ -504,7 +552,7 @@ MySettingsTab {
             font.pixelSize: theme.fontSizeLarge
             ToolTip.text: qsTr("Maximum length of response in tokens")
             ToolTip.visible: hovered
-            Layout.row: 1
+            Layout.row: 0
             Layout.column: 3
             validator: IntValidator {
                 bottom: 1
@@ -541,7 +589,7 @@ MySettingsTab {
             text: qsTr("Prompt Batch Size:")
             font.pixelSize: theme.fontSizeLarge
             color: theme.textColor
-            Layout.row: 2
+            Layout.row: 1
             Layout.column: 0
         }
         MyTextField {
@@ -552,7 +600,7 @@ MySettingsTab {
             font.pixelSize: theme.fontSizeLarge
             ToolTip.text: qsTr("Amount of prompt tokens to process at once.\nNOTE: Higher values can speed up reading prompts but will use more RAM")
             ToolTip.visible: hovered
-            Layout.row: 2
+            Layout.row: 1
             Layout.column: 1
             validator: IntValidator {
                 bottom: 1
@@ -588,8 +636,8 @@ MySettingsTab {
             text: qsTr("Repeat Penalty:")
             color: theme.textColor
             font.pixelSize: theme.fontSizeLarge
-            Layout.row: 2
-            Layout.column: 2
+            Layout.row: 3
+            Layout.column: 0
         }
         MyTextField {
             id: repeatPenaltyField
@@ -599,8 +647,8 @@ MySettingsTab {
             font.pixelSize: theme.fontSizeLarge
             ToolTip.text: qsTr("Amount to penalize repetitiveness of the output")
             ToolTip.visible: hovered
-            Layout.row: 2
-            Layout.column: 3
+            Layout.row: 3
+            Layout.column: 1
             validator: DoubleValidator {
                 locale: "C"
             }
@@ -636,7 +684,7 @@ MySettingsTab {
             color: theme.textColor
             font.pixelSize: theme.fontSizeLarge
             Layout.row: 3
-            Layout.column: 0
+            Layout.column: 2
         }
         MyTextField {
             id: repeatPenaltyTokenField
@@ -647,7 +695,7 @@ MySettingsTab {
             ToolTip.text: qsTr("How far back in output to apply repeat penalty")
             ToolTip.visible: hovered
             Layout.row: 3
-            Layout.column: 1
+            Layout.column: 3
             validator: IntValidator {
                 bottom: 1
             }