Implement configurable context length (#1749)

Author: Jared Van Bortel
Date: 2023-12-16 17:58:15 -05:00
Committed by: GitHub
Parent: 7aa0f779de
Commit: d1c56b8b28
31 changed files with 291 additions and 135 deletions

@@ -120,7 +120,8 @@ struct llama_file_hparams {
enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16;
};
size_t LLamaModel::requiredMem(const std::string &modelPath) {
size_t LLamaModel::requiredMem(const std::string &modelPath, int n_ctx) {
// TODO(cebtenzzre): update to GGUF
auto fin = std::ifstream(modelPath, std::ios::binary);
fin.seekg(0, std::ios_base::end);
size_t filesize = fin.tellg();
@@ -137,40 +138,31 @@ size_t LLamaModel::requiredMem(const std::string &modelPath) {
fin.read(reinterpret_cast<char*>(&hparams.n_layer), sizeof(hparams.n_layer));
fin.read(reinterpret_cast<char*>(&hparams.n_rot), sizeof(hparams.n_rot));
fin.read(reinterpret_cast<char*>(&hparams.ftype), sizeof(hparams.ftype));
const size_t n_ctx = 2048;
const size_t kvcache_element_size = 2; // fp16
const size_t est_kvcache_size = hparams.n_embd * hparams.n_layer * 2u * n_ctx * kvcache_element_size;
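// e.g. with assumed 7B-class hyperparameters (n_embd = 4096, n_layer = 32) and
// n_ctx = 2048: 4096 * 32 * 2 * 2048 * 2 bytes = 1 GiB of fp16 KV cache.
// The estimate scales linearly with n_ctx, hence the new parameter.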
return filesize + est_kvcache_size;
}
bool LLamaModel::loadModel(const std::string &modelPath)
bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx)
{
gpt_params params;
// load the model
if (n_ctx < 8) {
std::cerr << "warning: minimum context size is 8, using minimum size.\n";
n_ctx = 8;
}
// -- load the model --
d_ptr->model_params = llama_model_default_params();
d_ptr->model_params.use_mmap = params.use_mmap;
#if defined (__APPLE__)
d_ptr->model_params.use_mlock = true;
#else
d_ptr->model_params.use_mlock = params.use_mlock;
#endif
d_ptr->ctx_params = llama_context_default_params();
d_ptr->ctx_params.n_ctx = 2048;
d_ptr->ctx_params.seed = params.seed;
d_ptr->ctx_params.f16_kv = params.memory_f16;
// The new batch API provides space for n_vocab*n_tokens logits. Tell llama.cpp early
// that we want this many logits so the state serializes consistently.
d_ptr->ctx_params.logits_all = true;
d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
d_ptr->ctx_params.n_threads = d_ptr->n_threads;
d_ptr->ctx_params.n_threads_batch = d_ptr->n_threads;
#ifdef GGML_USE_METAL
if (llama_verbose()) {
std::cerr << "llama.cpp: using Metal" << std::endl;
@@ -197,6 +189,28 @@ bool LLamaModel::loadModel(const std::string &modelPath)
return false;
}
const int n_ctx_train = llama_n_ctx_train(d_ptr->model);
if (n_ctx > n_ctx_train) {
std::cerr << "warning: model was trained on only " << n_ctx_train << " context tokens ("
<< n_ctx << " specified)\n";
}
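// note: context initialization now happens below, after the model is loaded,
// so the requested n_ctx can be checked against llama_n_ctx_train() above
// before the llama.cpp context is created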
// -- initialize the context --
d_ptr->ctx_params = llama_context_default_params();
d_ptr->ctx_params.n_ctx = n_ctx;
d_ptr->ctx_params.seed = params.seed;
d_ptr->ctx_params.f16_kv = params.memory_f16;
// The new batch API provides space for n_vocab*n_tokens logits. Tell llama.cpp early
// that we want this many logits so the state serializes consistently.
d_ptr->ctx_params.logits_all = true;
d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
d_ptr->ctx_params.n_threads = d_ptr->n_threads;
d_ptr->ctx_params.n_threads_batch = d_ptr->n_threads;
d_ptr->ctx = llama_new_context_with_model(d_ptr->model, d_ptr->ctx_params);
if (!d_ptr->ctx) {
#ifdef GGML_USE_KOMPUTE
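
For reference, a hypothetical caller-side sketch of the two signatures changed in this commit. The header name, direct construction of LLamaModel, the model path, and the chosen n_ctx are assumptions for illustration only.

// Hypothetical caller-side sketch; header name and direct construction of
// LLamaModel are assumptions for illustration.
#include "llamamodel_impl.h"   // assumed header declaring LLamaModel

#include <cstdio>
#include <string>

int main() {
    const std::string path = "/path/to/model.gguf";  // illustrative path
    const int n_ctx = 4096;                          // requested context length

    LLamaModel model;

    // New in this commit: n_ctx flows into both the memory estimate and the
    // llama.cpp context created by loadModel().
    std::printf("estimated memory: %zu bytes\n", model.requiredMem(path, n_ctx));

    if (!model.loadModel(path, n_ctx)) {
        std::fprintf(stderr, "failed to load %s\n", path.c_str());
        return 1;
    }
    return 0;
}

Requesting more than llama_n_ctx_train() tokens only triggers the warning above; the context is still created with the requested size.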