mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-09-06 11:00:48 +00:00
backend: factor out common elements in model code (#1089)
* backend: factor out common structs in model code

  prepping to hack on these by hopefully making there be fewer places to fix the same bug

  rename

* use common buffer wrapper instead of manual malloc

* fix replit compile warnings
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
#include "gptj_impl.h"
|
||||
|
||||
#include "utils.h"
|
||||
#include "llmodel_shared.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
@@ -63,39 +64,6 @@ struct gptj_layer {
|
||||
struct ggml_tensor * c_mlp_proj_b;
|
||||
};
|
||||
|
||||
// Owning, resizable byte buffer backed by a single new[] allocation.
//
// The buffer is move-only: it holds a raw owning pointer, so a member-wise
// copy would leave two objects calling delete[] on the same allocation.
struct gptj_buffer {
    uint8_t * addr = nullptr; // start of the owned allocation, nullptr when empty
    size_t size = 0;          // number of bytes allocated at addr

    gptj_buffer() = default;

    // Copying is disabled to prevent a double delete[] of addr.
    gptj_buffer(const gptj_buffer &) = delete;
    gptj_buffer & operator=(const gptj_buffer &) = delete;

    // Moving transfers ownership and leaves the source empty.
    gptj_buffer(gptj_buffer && other) noexcept : addr(other.addr), size(other.size) {
        other.addr = nullptr;
        other.size = 0;
    }

    gptj_buffer & operator=(gptj_buffer && other) noexcept {
        if (this != &other) {
            delete[] addr;
            addr = other.addr;
            size = other.size;
            other.addr = nullptr;
            other.size = 0;
        }
        return *this;
    }

    // Discards the current allocation and replaces it with a new one of
    // `size` bytes. Existing contents are NOT preserved and the new bytes
    // are uninitialized.
    void resize(size_t size) {
        // Release and reset first: if the allocation below throws, the
        // buffer is left empty instead of pointing at freed memory (which
        // the destructor would otherwise delete a second time). Freeing
        // before allocating also avoids holding both blocks at once.
        delete[] addr;
        addr = nullptr;
        this->size = 0;

        addr = new uint8_t[size];
        this->size = size;
    }

    ~gptj_buffer() {
        // NOTE(review): the reason for flushing stdout here is not evident
        // from this block; kept to preserve existing behavior.
        fflush(stdout);
        delete[] addr;
    }
};
|
||||
|
||||
struct gptj_kv_cache {
|
||||
struct ggml_tensor * k;
|
||||
struct ggml_tensor * v;
|
||||
|
||||
struct ggml_context * ctx = NULL;
|
||||
|
||||
gptj_buffer buf;
|
||||
|
||||
int n; // number of tokens currently in the cache
|
||||
|
||||
~gptj_kv_cache() {
|
||||
if (ctx) {
|
||||
ggml_free(ctx);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct gptj_model {
|
||||
gptj_hparams hparams;
|
||||
|
||||
@@ -111,13 +79,13 @@ struct gptj_model {
|
||||
std::vector<gptj_layer> layers;
|
||||
|
||||
// key + value memory
|
||||
struct gptj_kv_cache kv_self;
|
||||
struct llm_kv_cache kv_self;
|
||||
|
||||
//
|
||||
struct ggml_context * ctx;
|
||||
std::map<std::string, struct ggml_tensor *> tensors;
|
||||
|
||||
gptj_buffer buf;
|
||||
llm_buffer buf;
|
||||
|
||||
~gptj_model() {
|
||||
if (ctx) {
|
||||
@@ -128,7 +96,7 @@ struct gptj_model {
|
||||
|
||||
static bool kv_cache_init(
|
||||
const struct gptj_hparams & hparams,
|
||||
struct gptj_kv_cache & cache,
|
||||
struct llm_kv_cache & cache,
|
||||
ggml_type wtype,
|
||||
int n_ctx) {
|
||||
const int n_embd = hparams.n_embd;
|
||||
|
Reference in New Issue
Block a user