Mirror of https://github.com/nomic-ai/gpt4all.git, synced 2025-09-06 11:00:48 +00:00
add requiredMem method to llmodel impls
Most of these can just shortcut out of the model loading logic. llama is a bit worse to deal with because we submodule it, so I have to at least parse the hparams; then I just use the size on disk as an estimate for the mem size (which seems reasonable, since we mmap() the llama files anyway).
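The llama-side change isn't shown in the diff below; a minimal sketch of the size-on-disk estimate described above (the function name is illustrative, not from this commit) could look like:

    #include <fstream>
    #include <string>

    // Hypothetical sketch: after parsing the hparams for sanity, use the
    // file size as the memory estimate, since llama weights are mmap()'d
    // rather than copied into RAM.
    static size_t llama_mem_estimate(const std::string &modelPath) {
        std::ifstream fin(modelPath, std::ios::binary | std::ios::ate);
        if (!fin) return 0;                       // unreadable file: no estimate
        return static_cast<size_t>(fin.tellg());  // bytes on disk ~= mapped size
    }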
@@ -158,8 +158,11 @@ static bool kv_cache_init(
 }
 
 // load the model's weights from a stream
-bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & model, gpt_vocab & vocab) {
+bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & model, gpt_vocab & vocab, size_t * mem_req = nullptr) {
     printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
+    if(mem_req != nullptr) {
+        *mem_req = 0;
+    }
 
     // verify magic
     {
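Because mem_req defaults to nullptr, existing call sites keep compiling unchanged; only a measurement pass supplies the out-parameter. A sketch of the two call forms (variable names illustrative):

    // Normal load: mem_req omitted, behavior unchanged.
    gptj_model_load(fname, fin, model, vocab);

    // Measurement-only pass: the estimate is written through mem_req and,
    // as the next hunk shows, the function returns before allocating weights.
    size_t mem_req = 0;
    gptj_model_load(fname, fin, model, vocab, &mem_req);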
@@ -276,6 +279,19 @@ bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & m
         printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
     }
 
+    if (mem_req != nullptr) {
+        *mem_req += ctx_size;
+        const int n_embd = model.hparams.n_embd;
+        const int n_layer = model.hparams.n_layer;
+
+        const int64_t n_mem = (int64_t)n_layer*model.hparams.n_ctx;
+        const int64_t n_elements = n_embd*n_mem;
+
+        *mem_req += (2u*n_elements*ggml_type_size(wtype) + 2_MiB);
+        return false;
+    }
+
+
     // create the ggml context
     {
         struct ggml_init_params params = {
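The estimate is ctx_size plus room for the KV cache: two tensors (keys and values) of n_layer * n_ctx * n_embd elements each, in wtype, plus a 2_MiB margin (a user-defined literal elsewhere in this file). The early return false is deliberate: a measurement pass "fails" the load before any allocation happens. A worked example of the KV-cache arithmetic, assuming GPT-J 6B shapes (n_layer = 28, n_ctx = 2048, n_embd = 4096) and a 2-byte f16 KV type; these values are illustrative, not from this diff:

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Assumed GPT-J 6B shapes; the real values come from the parsed hparams.
        const int n_layer = 28, n_ctx = 2048, n_embd = 4096;
        const size_t type_size = 2;  // f16 KV cache assumed

        const int64_t n_mem = (int64_t)n_layer * n_ctx;        // 57,344 cells
        const int64_t n_elements = (int64_t)n_embd * n_mem;    // 234,881,024
        const uint64_t kv_bytes = 2u * n_elements * type_size; // keys + values

        printf("KV cache: %llu bytes (~%.0f MiB)\n",
               (unsigned long long)kv_bytes, kv_bytes / (1024.0 * 1024.0));
        // -> 939,524,096 bytes (896 MiB), on top of ctx_size + the 2 MiB margin
        return 0;
    }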
@@ -837,6 +853,15 @@ GPTJ::GPTJ()
     d_ptr->modelLoaded = false;
 }
 
+size_t GPTJ::requiredMem(const std::string &modelPath) {
+    gptj_model dummy_model;
+    gpt_vocab dummy_vocab;
+    size_t mem_req;
+    auto fin = std::ifstream(modelPath, std::ios::binary);
+    gptj_model_load(modelPath, fin, dummy_model, dummy_vocab, &mem_req);
+    return mem_req;
+}
+
 bool GPTJ::loadModel(const std::string &modelPath) {
     std::mt19937 rng(time(NULL));
     d_ptr->rng = rng;
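A caller could use the new method as a preflight check before committing to a full load. A sketch under stated assumptions: the header name and the availableRam() helper are hypothetical, not part of this commit.

    #include <cstddef>
    #include <cstdio>
    #include <string>
    #include "gptj.h"  // header name assumed; adjust to the repo layout

    // Hypothetical RAM-budget helper; replace with a real query on your platform.
    static size_t availableRam() { return 8ull * 1024 * 1024 * 1024; }

    int main(int argc, char **argv) {
        if (argc < 2) return 1;
        const std::string modelPath = argv[1];

        GPTJ model;
        const size_t needed = model.requiredMem(modelPath);
        if (needed > availableRam()) {
            fprintf(stderr, "model needs ~%zu bytes, skipping load\n", needed);
            return 1;
        }
        return model.loadModel(modelPath) ? 0 : 1;
    }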