diff --git a/gpt4all-backend/llama.cpp-mainline b/gpt4all-backend/llama.cpp-mainline index 4458a8ea..cbebf61c 160000 --- a/gpt4all-backend/llama.cpp-mainline +++ b/gpt4all-backend/llama.cpp-mainline @@ -1 +1 @@ -Subproject commit 4458a8eaf443e7fa0e764682d22213fa4fef90c3 +Subproject commit cbebf61ca7584e9709265395f0127ae7fc0f1882 diff --git a/gpt4all-backend/replit.cpp b/gpt4all-backend/replit.cpp index d4425ff0..f8cbf688 100644 --- a/gpt4all-backend/replit.cpp +++ b/gpt4all-backend/replit.cpp @@ -517,6 +517,7 @@ bool replit_model_load(const std::string & fname, std::istream &fin, replit_mode model.ctx_metal = ggml_metal_init(); void* data_ptr = ggml_get_mem_buffer(model.ctx); size_t data_size = ggml_get_mem_size(model.ctx); + const size_t max_size = ggml_get_max_tensor_size(model.ctx); #define GGML_CHECK_BUF(result) if (!(result)) { \ std::cerr << __func__ << ": failed to add buffer" << std::endl; \ @@ -524,12 +525,12 @@ bool replit_model_load(const std::string & fname, std::istream &fin, replit_mode return false; \ } - GGML_CHECK_BUF(ggml_metal_add_buffer(model.ctx_metal, "data", data_ptr, data_size)); + GGML_CHECK_BUF(ggml_metal_add_buffer(model.ctx_metal, "data", data_ptr, data_size, max_size)); GGML_CHECK_BUF(ggml_metal_add_buffer(model.ctx_metal, "kv", ggml_get_mem_buffer(model.kv_self.ctx), - ggml_get_mem_size(model.kv_self.ctx))); - GGML_CHECK_BUF(ggml_metal_add_buffer(model.ctx_metal, "eval", model.eval_buf, model.eval_buf_size)); - GGML_CHECK_BUF(ggml_metal_add_buffer(model.ctx_metal, "scr0", model.scr0_buf, model.scr0_buf_size)); - GGML_CHECK_BUF(ggml_metal_add_buffer(model.ctx_metal, "scr1", model.scr1_buf, model.scr1_buf_size)); + ggml_get_mem_size(model.kv_self.ctx), 0)); + GGML_CHECK_BUF(ggml_metal_add_buffer(model.ctx_metal, "eval", model.eval_buf, model.eval_buf_size, 0)); + GGML_CHECK_BUF(ggml_metal_add_buffer(model.ctx_metal, "scr0", model.scr0_buf, model.scr0_buf_size, 0)); + GGML_CHECK_BUF(ggml_metal_add_buffer(model.ctx_metal, "scr1", model.scr1_buf, model.scr1_buf_size, 0)); #endif return true;