backend: factor out common elements in model code (#1089)

* backend: factor out common structs in model code

Preparation for further work on these models: with the shared structs factored out, the same bug should hopefully need fixing in fewer places.

rename

* use common buffer wrapper instead of manual malloc

* fix replit compile warnings
This commit is contained in:
Aaron Miller
2023-06-28 17:35:07 -07:00
committed by GitHub
parent 285aa50b60
commit 8d19ef3909
6 changed files with 81 additions and 202 deletions

View File

@@ -2,6 +2,7 @@
#include "gptj_impl.h"
#include "utils.h"
#include "llmodel_shared.h"
#include <cassert>
#include <cmath>
@@ -63,39 +64,6 @@ struct gptj_layer {
struct ggml_tensor * c_mlp_proj_b;
};
// Owning, heap-allocated byte buffer. The block at `addr` is released with
// delete[] when the buffer is destroyed.
//
// NOTE: copying is not disabled. An implicit copy would share `addr`, and both
// destructors would delete[] it, so pass this type by reference or pointer.
struct gptj_buffer {
    uint8_t * addr = NULL; // start of the owned allocation; NULL while empty
    size_t size = 0;       // number of bytes allocated at addr

    // Replaces the current allocation with a fresh, uninitialized block of
    // `size` bytes. Existing contents are discarded, not copied.
    void resize(size_t size) {
        delete[] addr;
        // Drop to the empty state before allocating: if new[] throws, the
        // destructor must not delete[] the block that was just released.
        addr = NULL;
        this->size = 0;

        addr = new uint8_t[size];
        this->size = size;
    }

    ~gptj_buffer() {
        // Flush retained from the original implementation; its purpose is not
        // evident from this struct alone.
        fflush(stdout);
        delete[] addr;
    }
};
// Key/value cache state for a model: two tensors, the ggml context that is
// freed together with the cache, and a backing byte buffer.
//
// Every member has a default so that a default-constructed cache is in a
// well-defined empty state (null tensors, no context, zero tokens).
struct gptj_kv_cache {
    struct ggml_tensor * k = NULL; // key tensor
    struct ggml_tensor * v = NULL; // value tensor

    // Context released by the destructor; NULL means nothing to free.
    // NOTE(review): presumably created by kv_cache_init - confirm.
    struct ggml_context * ctx = NULL;

    gptj_buffer buf;

    int n = 0; // number of tokens currently in the cache

    ~gptj_kv_cache() {
        if (ctx) {
            ggml_free(ctx);
        }
    }
};
struct gptj_model {
gptj_hparams hparams;
@@ -111,13 +79,13 @@ struct gptj_model {
std::vector<gptj_layer> layers;
// key + value memory
struct gptj_kv_cache kv_self;
struct llm_kv_cache kv_self;
//
struct ggml_context * ctx;
std::map<std::string, struct ggml_tensor *> tensors;
gptj_buffer buf;
llm_buffer buf;
~gptj_model() {
if (ctx) {
@@ -128,7 +96,7 @@ struct gptj_model {
static bool kv_cache_init(
const struct gptj_hparams & hparams,
struct gptj_kv_cache & cache,
struct llm_kv_cache & cache,
ggml_type wtype,
int n_ctx) {
const int n_embd = hparams.n_embd;