Update to latest llama.cpp (#1706)
@@ -71,9 +71,10 @@ static int llama_sample_top_p_top_k(
     int top_k,
     float top_p,
     float temp,
-    float repeat_penalty) {
-    auto logits = llama_get_logits(ctx);
-    auto n_vocab = llama_n_vocab(ctx);
+    float repeat_penalty,
+    int32_t pos) {
+    auto logits = llama_get_logits_ith(ctx, pos);
+    auto n_vocab = llama_n_vocab(llama_get_model(ctx));
     // Populate initial list of all candidates
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
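The new pos parameter exists because batched decoding no longer implies "logits for the last token": llama_get_logits() returned the single row that llama_eval() produced, while llama_decode() produces one row per batch entry that asked for logits. A minimal sketch of the new accessor (not part of the commit; it assumes the llama.cpp C API of this revision):

    // Logits exist only for batch rows that set batch.logits[i] = true.
    // This backend marks only the final row, so the sampler must name that
    // row explicitly instead of relying on an implicit "last token".
    float *row = llama_get_logits_ith(ctx, pos);        // pos = index of the marked row
    int n_vocab = llama_n_vocab(llama_get_model(ctx));  // vocab size now hangs off the model
    // row[id] is the unnormalized score of token id, for id in [0, n_vocab)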
@@ -82,21 +83,23 @@ static int llama_sample_top_p_top_k(
     }
     llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
     // Sample repeat penalty
-    llama_sample_repetition_penalty(nullptr, &candidates_p, last_n_tokens_data, last_n_tokens_size, repeat_penalty);
+    llama_sample_repetition_penalties(nullptr, &candidates_p, last_n_tokens_data, last_n_tokens_size, repeat_penalty, 0.0f, 0.0f);
     // Temperature sampling
     llama_sample_top_k(ctx, &candidates_p, top_k, 1);
     llama_sample_tail_free(ctx, &candidates_p, 1.0f, 1);
     llama_sample_typical(ctx, &candidates_p, 1.0f, 1);
     llama_sample_top_p(ctx, &candidates_p, top_p, 1);
-    llama_sample_temperature(ctx, &candidates_p, temp);
+    llama_sample_temp(ctx, &candidates_p, temp);
     return llama_sample_token(ctx, &candidates_p);
 }
 
 struct LLamaPrivate {
     const std::string modelPath;
     bool modelLoaded;
+    llama_model *model = nullptr;
     llama_context *ctx = nullptr;
-    llama_context_params params;
+    llama_model_params model_params;
+    llama_context_params ctx_params;
     int64_t n_threads = 0;
     std::vector<LLModel::Token> end_tokens;
 };
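Upstream merged llama_sample_repetition_penalty() and llama_sample_frequency_and_presence_penalties() into a single llama_sample_repetition_penalties(); passing 0.0f for the two new terms preserves the old behavior. For illustration (not part of the commit), the same call with each argument annotated:

    llama_sample_repetition_penalties(
        nullptr,             // context, used only for timing stats; may be null
        &candidates_p,       // candidate token/logit array, modified in place
        last_n_tokens_data,  // window of recent tokens to penalize
        last_n_tokens_size,  // length of that window
        repeat_penalty,      // multiplicative repeat penalty (the old parameter)
        0.0f,                // frequency penalty (new; 0 disables it)
        0.0f);               // presence penalty (new; 0 disables it)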
@@ -142,37 +145,46 @@ size_t LLamaModel::requiredMem(const std::string &modelPath) {
 
 bool LLamaModel::loadModel(const std::string &modelPath)
 {
-    // load the model
-    d_ptr->params = llama_context_default_params();
-
     gpt_params params;
-    d_ptr->params.n_ctx = 2048;
-    d_ptr->params.seed = params.seed;
-    d_ptr->params.f16_kv = params.memory_f16;
-    d_ptr->params.use_mmap = params.use_mmap;
+
+    // load the model
+    d_ptr->model_params = llama_model_default_params();
+
+    d_ptr->model_params.use_mmap = params.use_mmap;
 #if defined (__APPLE__)
-    d_ptr->params.use_mlock = true;
+    d_ptr->model_params.use_mlock = true;
 #else
-    d_ptr->params.use_mlock = params.use_mlock;
+    d_ptr->model_params.use_mlock = params.use_mlock;
 #endif
+
+    d_ptr->ctx_params = llama_context_default_params();
+
+    d_ptr->ctx_params.n_ctx = 2048;
+    d_ptr->ctx_params.seed = params.seed;
+    d_ptr->ctx_params.f16_kv = params.memory_f16;
+
+    d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
+    d_ptr->ctx_params.n_threads = d_ptr->n_threads;
+    d_ptr->ctx_params.n_threads_batch = d_ptr->n_threads;
+
 #ifdef GGML_USE_METAL
     if (llama_verbose()) {
         std::cerr << "llama.cpp: using Metal" << std::endl;
     }
     // metal always runs the whole model if n_gpu_layers is not 0, at least
     // currently
-    d_ptr->params.n_gpu_layers = 1;
+    d_ptr->model_params.n_gpu_layers = 1;
 #endif
 #ifdef GGML_USE_KOMPUTE
     if (ggml_vk_has_device()) {
         // vulkan always runs the whole model if n_gpu_layers is not 0, at least
         // currently
-        d_ptr->params.n_gpu_layers = 1;
+        d_ptr->model_params.n_gpu_layers = 1;
     }
 #endif
 
-    d_ptr->ctx = llama_init_from_file(modelPath.c_str(), d_ptr->params);
-    if (!d_ptr->ctx) {
+    d_ptr->model = llama_load_model_from_file_gpt4all(modelPath.c_str(), &d_ptr->model_params);
+    if (!d_ptr->model) {
 #ifdef GGML_USE_KOMPUTE
         // Explicitly free the device so next load it doesn't use it
         ggml_vk_free_device();
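The one-shot llama_init_from_file() is gone upstream; loading now happens in two steps, a model load and a context creation, which is why LLamaPrivate gained separate model_params and ctx_params. A minimal sketch against the upstream API of this era (llama_load_model_from_file_gpt4all is a project-local wrapper around the upstream call; the path and n_ctx below are illustrative):

    llama_model_params mparams = llama_model_default_params();
    llama_model *model = llama_load_model_from_file("model.gguf", mparams);
    if (!model) { /* report the load failure */ }

    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx = 2048;
    llama_context *ctx = llama_new_context_with_model(model, cparams);
    if (!ctx) { llama_free_model(model); /* report the init failure */ }

    // Teardown mirrors construction in reverse: context first, then model,
    // exactly as the updated ~LLamaModel() below does.
    llama_free(ctx);
    llama_free_model(model);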
@@ -181,7 +193,17 @@ bool LLamaModel::loadModel(const std::string &modelPath)
         return false;
     }
 
-    d_ptr->end_tokens = {llama_token_eos(d_ptr->ctx)};
+    d_ptr->ctx = llama_new_context_with_model(d_ptr->model, d_ptr->ctx_params);
+    if (!d_ptr->ctx) {
+#ifdef GGML_USE_KOMPUTE
+        // Explicitly free the device so next load it doesn't use it
+        ggml_vk_free_device();
+#endif
+        std::cerr << "LLAMA ERROR: failed to init context for model " << modelPath << std::endl;
+        return false;
+    }
+
+    d_ptr->end_tokens = {llama_token_eos(d_ptr->model)};
 
 #ifdef GGML_USE_KOMPUTE
     if (ggml_vk_has_device()) {
@@ -189,7 +211,6 @@ bool LLamaModel::loadModel(const std::string &modelPath)
     }
 #endif
 
-    d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
     d_ptr->modelLoaded = true;
     fflush(stderr);
     return true;
@@ -197,6 +218,7 @@ bool LLamaModel::loadModel(const std::string &modelPath)
 
 void LLamaModel::setThreadCount(int32_t n_threads) {
     d_ptr->n_threads = n_threads;
+    llama_set_n_threads(d_ptr->ctx, n_threads, n_threads);
 }
 
 int32_t LLamaModel::threadCount() const {
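setThreadCount() can now take effect after load because the context keeps two runtime-adjustable thread counts; the commit sets both to the same value. For illustration (assuming this API revision):

    // n_threads drives single-token generation; n_threads_batch drives
    // prompt ingestion, i.e. llama_decode() on multi-token batches.
    llama_set_n_threads(ctx, /*n_threads*/ 4, /*n_threads_batch*/ 4);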
@@ -208,6 +230,7 @@ LLamaModel::~LLamaModel()
     if (d_ptr->ctx) {
         llama_free(d_ptr->ctx);
     }
+    llama_free_model(d_ptr->model);
 }
 
 bool LLamaModel::isModelLoaded() const
@@ -233,16 +256,17 @@ size_t LLamaModel::restoreState(const uint8_t *src)
 
 std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::string &str) const
 {
-    const bool useBOS = ctx.n_past == 0 && (ctx.tokens.empty() || ctx.tokens.front() != llama_token_bos(d_ptr->ctx));
+    const bool useBOS = ctx.n_past == 0 && (ctx.tokens.empty() || ctx.tokens.front() != llama_token_bos(d_ptr->model));
     std::vector<LLModel::Token> fres(str.size()+4);
-    auto fres_len = llama_tokenize(d_ptr->ctx, str.c_str(), str.length(), fres.data(), fres.size(), useBOS);
+    // TODO(cebtenzzre): we may want to use special=true here to process special tokens
+    auto fres_len = llama_tokenize(d_ptr->model, str.c_str(), str.length(), fres.data(), fres.size(), useBOS, false);
     fres.resize(fres_len);
     return fres;
 }
 
 std::string LLamaModel::tokenToString(Token id) const
 {
-    return llama_token_to_str(d_ptr->ctx, id);
+    return llama_token_to_piece(d_ptr->ctx, id);
 }
 
 LLModel::Token LLamaModel::sampleToken(PromptContext &promptCtx) const
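llama_tokenize() now takes the model rather than the context, plus a trailing "special" flag that controls whether special-token text is parsed; the commit passes false, per the TODO above. A minimal sketch under the same assumptions (the input string is illustrative):

    std::string text = "Hello, world!";
    std::vector<llama_token> toks(text.size() + 4);  // headroom for BOS etc.
    int n = llama_tokenize(model, text.c_str(), text.length(),
                           toks.data(), toks.size(),
                           /*add_bos*/ true, /*special*/ false);
    if (n < 0) { /* buffer too small; -n is the required count */ }
    toks.resize(n);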
@@ -251,12 +275,30 @@ LLModel::Token LLamaModel::sampleToken(PromptContext &promptCtx) const
     return llama_sample_top_p_top_k(d_ptr->ctx,
         promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks,
         n_prev_toks, promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
-        promptCtx.repeat_penalty);
+        promptCtx.repeat_penalty, promptCtx.n_last_batch_tokens - 1);
 }
 
 bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const
 {
-    return llama_eval(d_ptr->ctx, tokens.data(), tokens.size(), ctx.n_past, d_ptr->n_threads) == 0;
+    llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
+
+    batch.n_tokens = tokens.size();
+    ctx.n_last_batch_tokens = tokens.size();
+
+    for (int32_t i = 0; i < batch.n_tokens; i++) {
+        batch.token   [i] = tokens[i];
+        batch.pos     [i] = ctx.n_past + i;
+        batch.n_seq_id[i] = 1;
+        batch.seq_id  [i][0] = 0;
+        batch.logits  [i] = false;
+    }
+
+    // llama_decode will output logits only for the last token of the prompt
+    batch.logits[batch.n_tokens - 1] = true;
+
+    int res = llama_decode(d_ptr->ctx, batch);
+    llama_batch_free(batch);
+    return res == 0;
 }
 
 int32_t LLamaModel::contextLength() const
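Note how the two halves of this change meet: evalTokens() records n_last_batch_tokens and requests logits only for the final batch row, and sampleToken() passes n_last_batch_tokens - 1 as pos, so llama_get_logits_ith() reads exactly the row that llama_decode() populated. A condensed round trip (not part of the commit; same API assumptions, n_tokens stands in for the prompt length):

    llama_batch batch = llama_batch_init(n_tokens, 0, 1);
    // ... fill batch.token/pos/n_seq_id/seq_id as in evalTokens() above ...
    batch.logits[n_tokens - 1] = true;  // only the last row produces logits
    if (llama_decode(ctx, batch) == 0) {
        float *logits = llama_get_logits_ith(ctx, n_tokens - 1);
        // feed logits into the top_k/top_p/temp chain from the first hunk
    }
    llama_batch_free(batch);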