mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-06-27 15:58:25 +00:00
backend: port GPT-J to GGUF
This commit is contained in:
parent
31b20f093a
commit
050e7f076e
@ -98,10 +98,9 @@ foreach(BUILD_VARIANT IN LISTS BUILD_VARIANTS)
|
||||
prepare_target(llamamodel-mainline llama-mainline)
|
||||
|
||||
if (NOT LLAMA_METAL)
|
||||
# FIXME: These need to be forward ported to latest ggml
|
||||
# add_library(gptj-${BUILD_VARIANT} SHARED
|
||||
# gptj.cpp utils.h utils.cpp llmodel_shared.cpp llmodel_shared.h)
|
||||
# prepare_target(gptj ggml-230511)
|
||||
add_library(gptj-${BUILD_VARIANT} SHARED
|
||||
gptj.cpp utils.h utils.cpp llmodel_shared.cpp llmodel_shared.h)
|
||||
prepare_target(gptj llama-mainline)
|
||||
|
||||
add_library(mpt-${BUILD_VARIANT} SHARED
|
||||
mpt.cpp utils.h utils.cpp llmodel_shared.cpp llmodel_shared.h)
|
||||
|
@ -8,7 +8,6 @@
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
@ -9,7 +9,6 @@
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
@ -42,7 +41,7 @@ struct gptj_hparams {
|
||||
int32_t n_head = 16;
|
||||
int32_t n_layer = 28;
|
||||
int32_t n_rot = 64;
|
||||
int32_t f16 = 1;
|
||||
float norm_eps = 1e-5;
|
||||
};
|
||||
|
||||
struct gptj_layer {
|
||||
@ -128,156 +127,113 @@ static bool kv_cache_init(
|
||||
return true;
|
||||
}
|
||||
|
||||
// load the model's weights from a stream
|
||||
bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & model, gpt_vocab & vocab, size_t * mem_req = nullptr) {
|
||||
// load the model's weights from a file path
|
||||
bool gptj_model_load(const std::string &fname, gptj_model & model, gpt_vocab & vocab, size_t * mem_req = nullptr) {
|
||||
printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
|
||||
if(mem_req != nullptr) {
|
||||
*mem_req = 0;
|
||||
}
|
||||
|
||||
// verify magic
|
||||
{
|
||||
uint32_t magic;
|
||||
fin.read((char *) &magic, sizeof(magic));
|
||||
if (magic != 0x67676d6c) {
|
||||
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
|
||||
return false;
|
||||
}
|
||||
// create the ggml context
|
||||
struct gguf_init_params params = {
|
||||
/*.no_alloc = */ false,
|
||||
/*.ctx = */ &model.ctx,
|
||||
};
|
||||
|
||||
gguf_context *ggufctx = gguf_init_from_file(fname.c_str(), params);
|
||||
if (!ggufctx) {
|
||||
fprintf(stderr, "%s: gguf_init_from_file() failed\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
// load hparams
|
||||
{
|
||||
auto & hparams = model.hparams;
|
||||
|
||||
fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
|
||||
fin.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
|
||||
fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd));
|
||||
fin.read((char *) &hparams.n_head, sizeof(hparams.n_head));
|
||||
fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
|
||||
fin.read((char *) &hparams.n_rot, sizeof(hparams.n_rot));
|
||||
fin.read((char *) &hparams.f16, sizeof(hparams.f16));
|
||||
bool ok = false;
|
||||
int keyidx;
|
||||
|
||||
do {
|
||||
keyidx = gguf_find_key(ggufctx, "gptj.context_length");
|
||||
if (keyidx == -1) { break; }
|
||||
hparams.n_ctx = gguf_get_val_u32(ggufctx, keyidx);
|
||||
|
||||
keyidx = gguf_find_key(ggufctx, "gptj.embedding_length");
|
||||
if (keyidx == -1) { break; }
|
||||
hparams.n_embd = gguf_get_val_u32(ggufctx, keyidx);
|
||||
|
||||
keyidx = gguf_find_key(ggufctx, "gptj.attention.head_count");
|
||||
if (keyidx == -1) { break; }
|
||||
hparams.n_head = gguf_get_val_u32(ggufctx, keyidx);
|
||||
|
||||
keyidx = gguf_find_key(ggufctx, "gptj.block_count");
|
||||
if (keyidx == -1) { break; }
|
||||
hparams.n_layer = gguf_get_val_u32(ggufctx, keyidx);
|
||||
|
||||
keyidx = gguf_find_key(ggufctx, "gptj.rope.dimension_count");
|
||||
if (keyidx == -1) { break; }
|
||||
hparams.n_rot = gguf_get_val_u32(ggufctx, keyidx);
|
||||
|
||||
keyidx = gguf_find_key(ggufctx, "gptj.attention.layer_norm_epsilon");
|
||||
if (keyidx == -1) { break; }
|
||||
hparams.norm_eps = gguf_get_val_f32(ggufctx, keyidx);
|
||||
|
||||
ok = true;
|
||||
} while (false);
|
||||
|
||||
if (!ok) {
|
||||
fprintf(stderr, "%s: required hparam missing!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
|
||||
printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx);
|
||||
printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
|
||||
printf("%s: n_head = %d\n", __func__, hparams.n_head);
|
||||
printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
|
||||
printf("%s: n_rot = %d\n", __func__, hparams.n_rot);
|
||||
printf("%s: f16 = %d\n", __func__, hparams.f16);
|
||||
}
|
||||
|
||||
// load vocab
|
||||
{
|
||||
int32_t n_vocab = 0;
|
||||
fin.read((char *) &n_vocab, sizeof(n_vocab));
|
||||
auto & hparams = model.hparams;
|
||||
|
||||
if (n_vocab != model.hparams.n_vocab) {
|
||||
fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
|
||||
__func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
|
||||
int keyidx = gguf_find_key(ggufctx, "tokenizer.ggml.model");
|
||||
if (keyidx == -1) {
|
||||
fprintf(stderr, "%s: tokenizer model not found!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
if (strcmp(gguf_get_val_str(ggufctx, keyidx), "gpt2") != 0) {
|
||||
fprintf(stderr, "%s: tokenizer model not supported!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string word;
|
||||
for (int i = 0; i < n_vocab; i++) {
|
||||
uint32_t len;
|
||||
fin.read((char *) &len, sizeof(len));
|
||||
int tokens_keyidx = gguf_find_key(ggufctx, "tokenizer.ggml.tokens");
|
||||
if (tokens_keyidx == -1) {
|
||||
fprintf(stderr, "%s: gpt2 tokenizer vocab not found!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
word.resize(len);
|
||||
fin.read((char *) word.data(), len);
|
||||
hparams.n_vocab = gguf_get_arr_n(ggufctx, tokens_keyidx);
|
||||
printf("%s: gpt2 tokenizer vocab = %d\n", __func__, int(hparams.n_vocab));
|
||||
|
||||
for (int i = 0; i < hparams.n_vocab; i++) {
|
||||
std::string word = gguf_get_arr_str(ggufctx, tokens_keyidx, i);
|
||||
vocab.token_to_id[word] = i;
|
||||
vocab.id_to_token[i] = word;
|
||||
}
|
||||
}
|
||||
|
||||
// for the big tensors, we have the option to store the data in 16-bit floats or quantized
|
||||
// in order to save memory and also to speed up the computation
|
||||
ggml_type wtype = GGML_TYPE_COUNT;
|
||||
switch (model.hparams.f16) {
|
||||
case 0: wtype = GGML_TYPE_F32; break;
|
||||
case 1: wtype = GGML_TYPE_F16; break;
|
||||
case 2: wtype = GGML_TYPE_Q4_0; break;
|
||||
case 3: wtype = GGML_TYPE_Q4_1; break;
|
||||
case 5: wtype = GGML_TYPE_Q4_2; break;
|
||||
default:
|
||||
{
|
||||
fprintf(stderr, "%s: invalid model file '%s' (bad f16 value %d)\n",
|
||||
__func__, fname.c_str(), model.hparams.f16);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
auto & ctx = model.ctx;
|
||||
|
||||
size_t ctx_size = 0;
|
||||
|
||||
{
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const int n_embd = hparams.n_embd;
|
||||
const int n_layer = hparams.n_layer;
|
||||
const int n_ctx = hparams.n_ctx;
|
||||
const int n_vocab = hparams.n_vocab;
|
||||
|
||||
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
|
||||
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b
|
||||
|
||||
ctx_size += n_embd*n_vocab*ggml_type_sizef(wtype); // wte
|
||||
|
||||
ctx_size += n_embd*n_vocab*ggml_type_sizef(wtype); // lmh_g
|
||||
ctx_size += n_vocab*ggml_type_sizef(GGML_TYPE_F32); // lmh_b
|
||||
|
||||
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
|
||||
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
|
||||
|
||||
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_q_proj_w
|
||||
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_k_proj_w
|
||||
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_v_proj_w
|
||||
|
||||
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w
|
||||
|
||||
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_fc_w
|
||||
ctx_size += n_layer*( 4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
|
||||
|
||||
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_proj_w
|
||||
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b
|
||||
|
||||
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
|
||||
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
|
||||
|
||||
ctx_size += (5 + 10*n_layer)*256; // object overhead
|
||||
|
||||
printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
|
||||
}
|
||||
size_t ctx_size = ggml_get_mem_size(ctx);
|
||||
printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size / (1024.0 * 1024.0));
|
||||
|
||||
if (mem_req != nullptr) {
|
||||
*mem_req += ctx_size;
|
||||
const int n_embd = model.hparams.n_embd;
|
||||
const int n_layer = model.hparams.n_layer;
|
||||
|
||||
const int64_t n_mem = (int64_t)n_layer*model.hparams.n_ctx;
|
||||
const int64_t n_elements = n_embd*n_mem;
|
||||
|
||||
*mem_req += (2u*n_elements*ggml_type_size(wtype) + 2_MiB);
|
||||
*mem_req = ctx_size;
|
||||
gguf_free(ggufctx);
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
// create the ggml context
|
||||
{
|
||||
struct ggml_init_params params = {
|
||||
.mem_size = ctx_size,
|
||||
.mem_buffer = NULL,
|
||||
.no_alloc = false
|
||||
};
|
||||
|
||||
model.ctx = ggml_init(params);
|
||||
if (!model.ctx) {
|
||||
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// prepare memory for the weights
|
||||
{
|
||||
const auto & hparams = model.hparams;
|
||||
@ -288,56 +244,37 @@ bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & m
|
||||
|
||||
model.layers.resize(n_layer);
|
||||
|
||||
model.wte = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
|
||||
model.wte = ggml_get_tensor(ctx, "token_embd.weight");
|
||||
|
||||
model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
model.ln_f_g = ggml_get_tensor(ctx, "output_norm.weight");
|
||||
model.ln_f_b = ggml_get_tensor(ctx, "output_norm.bias");
|
||||
|
||||
model.lmh_g = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
|
||||
model.lmh_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_vocab);
|
||||
model.lmh_g = ggml_get_tensor(ctx, "output.weight");
|
||||
model.lmh_b = ggml_get_tensor(ctx, "output.bias");
|
||||
|
||||
// map by name
|
||||
model.tensors["transformer.wte.weight"] = model.wte;
|
||||
|
||||
model.tensors["transformer.ln_f.weight"] = model.ln_f_g;
|
||||
model.tensors["transformer.ln_f.bias"] = model.ln_f_b;
|
||||
|
||||
model.tensors["lm_head.weight"] = model.lmh_g;
|
||||
model.tensors["lm_head.bias"] = model.lmh_b;
|
||||
auto name = [](int i, std::string n) {
|
||||
static std::string key;
|
||||
key = "blk." + std::to_string(i) + "." + n;
|
||||
return key.c_str();
|
||||
};
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
layer.ln_1_g = ggml_get_tensor(ctx, name(i, "attn_norm.weight"));
|
||||
layer.ln_1_b = ggml_get_tensor(ctx, name(i, "attn_norm.bias"));
|
||||
|
||||
layer.c_attn_q_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
|
||||
layer.c_attn_k_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
|
||||
layer.c_attn_v_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
|
||||
layer.c_attn_q_proj_w = ggml_get_tensor(ctx, name(i, "attn_q.weight"));
|
||||
layer.c_attn_k_proj_w = ggml_get_tensor(ctx, name(i, "attn_k.weight"));
|
||||
layer.c_attn_v_proj_w = ggml_get_tensor(ctx, name(i, "attn_v.weight"));
|
||||
|
||||
layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
|
||||
layer.c_attn_proj_w = ggml_get_tensor(ctx, name(i, "attn_output.weight"));
|
||||
|
||||
layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 4*n_embd);
|
||||
layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
|
||||
layer.c_mlp_fc_w = ggml_get_tensor(ctx, name(i, "ffn_up.weight"));
|
||||
layer.c_mlp_fc_b = ggml_get_tensor(ctx, name(i, "ffn_up.bias"));
|
||||
|
||||
layer.c_mlp_proj_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
|
||||
layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
// map by name
|
||||
model.tensors["transformer.h." + std::to_string(i) + ".ln_1.weight"] = layer.ln_1_g;
|
||||
model.tensors["transformer.h." + std::to_string(i) + ".ln_1.bias"] = layer.ln_1_b;
|
||||
|
||||
model.tensors["transformer.h." + std::to_string(i) + ".attn.q_proj.weight"] = layer.c_attn_q_proj_w;
|
||||
model.tensors["transformer.h." + std::to_string(i) + ".attn.k_proj.weight"] = layer.c_attn_k_proj_w;
|
||||
model.tensors["transformer.h." + std::to_string(i) + ".attn.v_proj.weight"] = layer.c_attn_v_proj_w;
|
||||
|
||||
model.tensors["transformer.h." + std::to_string(i) + ".attn.out_proj.weight"] = layer.c_attn_proj_w;
|
||||
|
||||
model.tensors["transformer.h." + std::to_string(i) + ".mlp.fc_in.weight"] = layer.c_mlp_fc_w;
|
||||
model.tensors["transformer.h." + std::to_string(i) + ".mlp.fc_in.bias"] = layer.c_mlp_fc_b;
|
||||
|
||||
model.tensors["transformer.h." + std::to_string(i) + ".mlp.fc_out.weight"] = layer.c_mlp_proj_w;
|
||||
model.tensors["transformer.h." + std::to_string(i) + ".mlp.fc_out.bias"] = layer.c_mlp_proj_b;
|
||||
layer.c_mlp_proj_w = ggml_get_tensor(ctx, name(i, "ffn_down.weight"));
|
||||
layer.c_mlp_proj_b = ggml_get_tensor(ctx, name(i, "ffn_down.bias"));
|
||||
}
|
||||
}
|
||||
|
||||
@ -354,113 +291,12 @@ bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & m
|
||||
printf("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
|
||||
}
|
||||
|
||||
// load weights
|
||||
{
|
||||
int n_tensors = 0;
|
||||
size_t total_size = 0;
|
||||
|
||||
printf("%s: ", __func__);
|
||||
|
||||
while (true) {
|
||||
int32_t n_dims;
|
||||
int32_t length;
|
||||
int32_t ftype;
|
||||
|
||||
fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
|
||||
fin.read(reinterpret_cast<char *>(&length), sizeof(length));
|
||||
fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
|
||||
|
||||
if (fin.eof()) {
|
||||
break;
|
||||
}
|
||||
|
||||
int32_t nelements = 1;
|
||||
int32_t ne[2] = { 1, 1 };
|
||||
for (int i = 0; i < n_dims; ++i) {
|
||||
fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
|
||||
nelements *= ne[i];
|
||||
}
|
||||
|
||||
std::string name(length, 0);
|
||||
fin.read(&name[0], length);
|
||||
|
||||
if (model.tensors.find(name.data()) == model.tensors.end()) {
|
||||
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
|
||||
return false;
|
||||
}
|
||||
|
||||
auto tensor = model.tensors[name.data()];
|
||||
if (ggml_nelements(tensor) != nelements) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%" PRId64 ", %" PRId64 "], expected [%d, %d]\n",
|
||||
__func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (0) {
|
||||
static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
|
||||
printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ftype_str[ftype], ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
|
||||
}
|
||||
|
||||
size_t bpe = 0;
|
||||
|
||||
switch (ftype) {
|
||||
case 0: bpe = ggml_type_size(GGML_TYPE_F32); break;
|
||||
case 1: bpe = ggml_type_size(GGML_TYPE_F16); break;
|
||||
case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break;
|
||||
case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break;
|
||||
default:
|
||||
{
|
||||
fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype);
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
|
||||
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
|
||||
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
|
||||
return false;
|
||||
}
|
||||
|
||||
fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
|
||||
|
||||
//printf("%42s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
|
||||
total_size += ggml_nbytes(tensor);
|
||||
if (++n_tensors % 8 == 0) {
|
||||
printf(".");
|
||||
fflush(stdout);
|
||||
}
|
||||
}
|
||||
|
||||
printf(" done\n");
|
||||
|
||||
printf("%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, n_tensors);
|
||||
}
|
||||
|
||||
model.scr0_buf.resize(256u * 1024 * 1024);
|
||||
model.scr1_buf.resize(256u * 1024 * 1024);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// load the model's weights from a file path
|
||||
bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab & vocab) {
|
||||
|
||||
auto fin = std::ifstream(fname, std::ios::binary);
|
||||
if (!fin) {
|
||||
fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
bool loaded = gptj_model_load(fname, fin, model, vocab);
|
||||
fin.close();
|
||||
return loaded;
|
||||
}
|
||||
|
||||
// evaluate the transformer
|
||||
//
|
||||
// - model: the model
|
||||
@ -513,7 +349,6 @@ bool gptj_eval(
|
||||
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
struct ggml_cgraph gf = {};
|
||||
gf.n_threads = n_threads;
|
||||
|
||||
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
|
||||
@ -526,7 +361,7 @@ bool gptj_eval(
|
||||
ggml_set_scratch(ctx0, {0, model.scr0_buf.size, model.scr0_buf.addr, });
|
||||
// norm
|
||||
{
|
||||
cur = ggml_norm(ctx0, inpL);
|
||||
cur = ggml_norm(ctx0, inpL, model.hparams.norm_eps);
|
||||
|
||||
// cur = ln_1_g*cur + ln_1_b
|
||||
cur = ggml_add(ctx0,
|
||||
@ -560,7 +395,7 @@ bool gptj_eval(
|
||||
ggml_cpy(ctx0,
|
||||
Qcur,
|
||||
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
|
||||
n_past, n_rot, 0),
|
||||
n_past, n_rot, 0, 0),
|
||||
0, 2, 1, 3);
|
||||
|
||||
// K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
|
||||
@ -570,7 +405,7 @@ bool gptj_eval(
|
||||
ggml_reshape_3d(ctx0,
|
||||
ggml_view_1d(ctx0, model.kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.kv_self.k)*n_embd),
|
||||
n_embd/n_head, n_head, n_past + N),
|
||||
n_past, n_rot, 1),
|
||||
n_past, n_rot, 1, 0),
|
||||
0, 2, 1, 3);
|
||||
|
||||
// K * Q
|
||||
@ -656,7 +491,7 @@ bool gptj_eval(
|
||||
|
||||
// norm
|
||||
{
|
||||
inpL = ggml_norm(ctx0, inpL);
|
||||
inpL = ggml_norm(ctx0, inpL, model.hparams.norm_eps);
|
||||
|
||||
// inpL = ln_f_g*inpL + ln_f_b
|
||||
inpL = ggml_add(ctx0,
|
||||
@ -680,9 +515,18 @@ bool gptj_eval(
|
||||
// logits -> probs
|
||||
//inpL = ggml_soft_max(ctx0, inpL);
|
||||
|
||||
// run the computation
|
||||
ggml_build_forward_expand(&gf, inpL);
|
||||
ggml_graph_compute (ctx0, &gf);
|
||||
|
||||
// run the computation
|
||||
{
|
||||
std::unique_ptr<uint8_t []> data;
|
||||
auto plan = ggml_graph_plan(&gf, n_threads);
|
||||
if (plan.work_size > 0) {
|
||||
data.reset(new uint8_t[plan.work_size]);
|
||||
plan.work_data = data.get();
|
||||
}
|
||||
ggml_graph_compute(&gf, &plan);
|
||||
}
|
||||
|
||||
//if (n_past%100 == 0) {
|
||||
// ggml_graph_print (&gf);
|
||||
@ -836,8 +680,7 @@ size_t GPTJ::requiredMem(const std::string &modelPath) {
|
||||
gptj_model dummy_model;
|
||||
gpt_vocab dummy_vocab;
|
||||
size_t mem_req;
|
||||
auto fin = std::ifstream(modelPath, std::ios::binary);
|
||||
gptj_model_load(modelPath, fin, dummy_model, dummy_vocab, &mem_req);
|
||||
gptj_model_load(modelPath, dummy_model, dummy_vocab, &mem_req);
|
||||
return mem_req;
|
||||
}
|
||||
|
||||
@ -845,10 +688,8 @@ bool GPTJ::loadModel(const std::string &modelPath) {
|
||||
std::mt19937 rng(time(NULL));
|
||||
d_ptr->rng = rng;
|
||||
|
||||
auto fin = std::ifstream(modelPath, std::ios::binary);
|
||||
|
||||
// load the model
|
||||
if (!gptj_model_load(modelPath, fin, *d_ptr->model, d_ptr->vocab)) {
|
||||
if (!gptj_model_load(modelPath, *d_ptr->model, d_ptr->vocab)) {
|
||||
std::cerr << "GPT-J ERROR: failed to load model from " << modelPath;
|
||||
return false;
|
||||
}
|
||||
@ -939,6 +780,16 @@ const std::vector<LLModel::Token> &GPTJ::endTokens() const
|
||||
return fres;
|
||||
}
|
||||
|
||||
std::string get_arch_name(gguf_context *ctx_gguf) {
|
||||
std::string arch_name;
|
||||
const int kid = gguf_find_key(ctx_gguf, "general.architecture");
|
||||
enum gguf_type ktype = gguf_get_kv_type(ctx_gguf, kid);
|
||||
if (ktype != GGUF_TYPE_STRING) {
|
||||
throw std::runtime_error("ERROR: Can't get general architecture from gguf file.");
|
||||
}
|
||||
return gguf_get_val_str(ctx_gguf, kid);
|
||||
}
|
||||
|
||||
#if defined(_WIN32)
|
||||
#define DLL_EXPORT __declspec(dllexport)
|
||||
#else
|
||||
@ -958,15 +809,21 @@ DLL_EXPORT const char *get_build_variant() {
|
||||
return GGML_BUILD_VARIANT;
|
||||
}
|
||||
|
||||
DLL_EXPORT bool magic_match(std::istream& f) {
|
||||
uint32_t magic = 0;
|
||||
f.read(reinterpret_cast<char*>(&magic), sizeof(magic));
|
||||
gptj_hparams hparams;
|
||||
f.read(reinterpret_cast<char*>(&hparams), sizeof(hparams));
|
||||
if (!(hparams.n_vocab >= 50300 && hparams.n_vocab <= 50400)) {
|
||||
return false; // not a gptj.
|
||||
}
|
||||
return magic == 0x67676d6c;
|
||||
DLL_EXPORT bool magic_match(const char * fname) {
|
||||
struct ggml_context * ctx_meta = NULL;
|
||||
struct gguf_init_params params = {
|
||||
/*.no_alloc = */ true,
|
||||
/*.ctx = */ &ctx_meta,
|
||||
};
|
||||
gguf_context *ctx_gguf = gguf_init_from_file(fname, params);
|
||||
if (!ctx_gguf)
|
||||
return false;
|
||||
|
||||
bool isValid = gguf_get_version(ctx_gguf) <= 2;
|
||||
isValid = isValid && get_arch_name(ctx_gguf) == "gptj";
|
||||
|
||||
gguf_free(ctx_gguf);
|
||||
return isValid;
|
||||
}
|
||||
|
||||
DLL_EXPORT LLModel *construct() {
|
||||
|
@ -9,7 +9,6 @@
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
#include <random>
|
||||
#include <string>
|
||||
|
184
gpt4all-backend/scripts/convert_gptj_to_gguf.py
Normal file
184
gpt4all-backend/scripts/convert_gptj_to_gguf.py
Normal file
@ -0,0 +1,184 @@
|
||||
# Convert GPT-J-6B h5 transformer model to ggml format
|
||||
#
|
||||
# Load the model using GPTJForCausalLM.
|
||||
# Iterate over all variables and write them to a binary file.
|
||||
#
|
||||
# For each variable, write the following:
|
||||
# - Number of dimensions (int)
|
||||
# - Name length (int)
|
||||
# - Dimensions (int[n_dims])
|
||||
# - Name (char[name_length])
|
||||
# - Data (float[n_dims])
|
||||
#
|
||||
# By default, the bigger matrices are converted to 16-bit floats.
|
||||
# This can be disabled by adding the "ftype" CLI argument.
|
||||
#
|
||||
# At the start of the ggml file we write the model parameters
|
||||
# and vocabulary.
|
||||
#
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import struct
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import gguf
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer, GPTJConfig, GPTJForCausalLM
|
||||
|
||||
|
||||
# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
|
||||
def bytes_to_unicode():
|
||||
"""
|
||||
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
||||
The reversible bpe codes work on unicode strings.
|
||||
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
||||
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
||||
This is a significant percentage of your normal, say, 32K bpe vocab.
|
||||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
||||
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
||||
"""
|
||||
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
||||
cs = bs[:]
|
||||
n = 0
|
||||
for b in range(2**8):
|
||||
if b not in bs:
|
||||
bs.append(b)
|
||||
cs.append(2**8+n)
|
||||
n += 1
|
||||
return dict(zip(bs, (chr(n) for n in cs)))
|
||||
|
||||
|
||||
if not 2 <= len(sys.argv) < 4:
|
||||
print("Usage: python {} dir-model [ftype]\n".format(Path(__file__).name))
|
||||
print(" ftype == 0 -> float32")
|
||||
print(" ftype == 1 -> float16")
|
||||
sys.exit(1)
|
||||
|
||||
# output in the same directory as the model
|
||||
dir_model = Path(sys.argv[1])
|
||||
fname_out = dir_model / "ggml-model.gguf"
|
||||
|
||||
# possible data types
|
||||
# ftype == 0 -> float32
|
||||
# ftype == 1 -> float16
|
||||
#
|
||||
# map from ftype to string
|
||||
ftype_str = ["f32", "f16"]
|
||||
|
||||
ftype = 1
|
||||
if len(sys.argv) > 2:
|
||||
ftype = int(sys.argv[2])
|
||||
if ftype < 0 or ftype > 1:
|
||||
print("Invalid ftype: " + str(ftype))
|
||||
sys.exit(1)
|
||||
|
||||
fname_out = dir_model / ("ggml-model-" + ftype_str[ftype] + ".gguf")
|
||||
|
||||
|
||||
ARCH = gguf.MODEL_ARCH.GPTJ
|
||||
gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])
|
||||
|
||||
print("gguf: get model metadata")
|
||||
|
||||
config = GPTJConfig(dir_model)
|
||||
|
||||
block_count = config.n_layer
|
||||
gguf_writer.add_name("GPT-J")
|
||||
gguf_writer.add_context_length(config.n_positions)
|
||||
gguf_writer.add_embedding_length(config.n_embd)
|
||||
gguf_writer.add_block_count(block_count)
|
||||
gguf_writer.add_head_count(config.n_head)
|
||||
gguf_writer.add_rope_dimension_count(config.rotary_dim)
|
||||
gguf_writer.add_layer_norm_eps(config.layer_norm_epsilon)
|
||||
gguf_writer.add_file_type(ftype)
|
||||
|
||||
print("gguf: get gpt2 tokenizer vocab")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(dir_model)
|
||||
|
||||
reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
|
||||
byte_encoder = bytes_to_unicode()
|
||||
byte_decoder = {v: k for k, v in byte_encoder.items()}
|
||||
|
||||
tokens: list[bytearray] = []
|
||||
|
||||
for i in range(config.vocab_size):
|
||||
if i in reverse_vocab:
|
||||
try:
|
||||
text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
|
||||
except KeyError:
|
||||
text = bytearray()
|
||||
for c in reverse_vocab[i]:
|
||||
if ord(c) < 256: # single byte character
|
||||
text.append(byte_decoder[c])
|
||||
else: # multibyte special token character
|
||||
text.extend(c.encode('utf-8'))
|
||||
else:
|
||||
print(f"Key {i} not in tokenizer vocabulary. Padding with an arbitrary token.")
|
||||
pad_token = f"[PAD{i}]".encode("utf8")
|
||||
text = bytearray(pad_token)
|
||||
|
||||
tokens.append(text)
|
||||
|
||||
|
||||
gguf_writer.add_tokenizer_model("gpt2")
|
||||
gguf_writer.add_token_list(tokens)
|
||||
|
||||
special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
|
||||
special_vocab.add_to_gguf(gguf_writer)
|
||||
|
||||
print("gguf: get tensor metadata")
|
||||
|
||||
model = GPTJForCausalLM.from_pretrained(dir_model, config=config, low_cpu_mem_usage=True)
|
||||
#print (model)
|
||||
|
||||
tensor_map = gguf.get_tensor_name_map(ARCH, block_count)
|
||||
|
||||
list_vars = model.state_dict()
|
||||
#print (list_vars)
|
||||
|
||||
for name in list_vars.keys():
|
||||
data = list_vars[name].squeeze().numpy()
|
||||
print("Processing variable:", name, "with shape:", data.shape)
|
||||
|
||||
# we don't need these
|
||||
if name.endswith("attn.masked_bias") or name.endswith(".attn.bias"):
|
||||
print(" Skipping variable:", name)
|
||||
continue
|
||||
|
||||
n_dims = len(data.shape)
|
||||
|
||||
# ftype == 0 -> float32, ftype == 1 -> float16
|
||||
ftype_cur = 0
|
||||
if ftype == 1 and name[-7:] == ".weight" and n_dims == 2:
|
||||
print(" Converting to float16")
|
||||
data = data.astype(np.float16)
|
||||
ftype_cur = 1
|
||||
elif ftype == 1 or data.dtype != np.float32:
|
||||
print(" Converting to float32")
|
||||
data = data.astype(np.float32)
|
||||
ftype_cur = 0
|
||||
|
||||
# map tensor names
|
||||
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
|
||||
if new_name is None:
|
||||
print("Can not map tensor '" + name + "'")
|
||||
sys.exit()
|
||||
|
||||
gguf_writer.add_tensor(new_name, data)
|
||||
|
||||
|
||||
print("gguf: write header")
|
||||
gguf_writer.write_header_to_file()
|
||||
print("gguf: write metadata")
|
||||
gguf_writer.write_kv_data_to_file()
|
||||
print("gguf: write tensors")
|
||||
gguf_writer.write_tensors_to_file()
|
||||
|
||||
gguf_writer.close()
|
||||
|
||||
print(f"gguf: model successfully exported to '{fname_out}'")
|
||||
print()
|
Loading…
Reference in New Issue
Block a user