Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-09-05 02:20:28 +00:00)
add min_p sampling parameter (#2014)
Signed-off-by: Christopher Barrera <cb@arda.tx.rr.com>
Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>
@@ -64,6 +64,7 @@ static int llama_sample_top_p_top_k(
                    int last_n_tokens_size,
                    int top_k,
                    float top_p,
+                   float min_p,
                    float temp,
                    float repeat_penalty,
                    int32_t pos) {

@@ -83,6 +84,7 @@ static int llama_sample_top_p_top_k(
     llama_sample_tail_free(ctx, &candidates_p, 1.0f, 1);
     llama_sample_typical(ctx, &candidates_p, 1.0f, 1);
     llama_sample_top_p(ctx, &candidates_p, top_p, 1);
+    llama_sample_min_p(ctx, &candidates_p, min_p, 1);
     llama_sample_temp(ctx, &candidates_p, temp);
     return llama_sample_token(ctx, &candidates_p);
 }
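
For context: llama.cpp's min-p step discards candidates whose probability falls below min_p times the probability of the most likely candidate, so min_p = 0.0f leaves the distribution untouched. The following is a minimal standalone sketch of that rule, not the library's implementation; the helper name and surrounding code are illustrative only.

    // Sketch of the min-p rule: keep candidates with probability >= min_p * max(prob).
    #include <algorithm>
    #include <vector>

    std::vector<float> min_p_filter(const std::vector<float> &probs, float min_p) {
        if (min_p <= 0.0f || probs.empty())
            return probs;                              // min_p == 0 disables the filter
        const float max_prob  = *std::max_element(probs.begin(), probs.end());
        const float threshold = min_p * max_prob;      // cutoff scales with the top probability
        std::vector<float> kept;
        for (float p : probs)
            if (p >= threshold)
                kept.push_back(p);
        return kept;                                   // caller renormalizes before sampling
    }
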
@@ -392,7 +394,7 @@ LLModel::Token LLamaModel::sampleToken(PromptContext &promptCtx) const
     const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, promptCtx.tokens.size());
     return llama_sample_top_p_top_k(d_ptr->ctx,
         promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks,
-        n_prev_toks, promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
+        n_prev_toks, promptCtx.top_k, promptCtx.top_p, promptCtx.min_p, promptCtx.temp,
         promptCtx.repeat_penalty, promptCtx.n_last_batch_tokens - 1);
 }

@@ -66,6 +66,7 @@ public:
         int32_t n_predict = 200;
         int32_t top_k = 40;
         float top_p = 0.9f;
+        float min_p = 0.0f;
         float temp = 0.9f;
         int32_t n_batch = 9;
         float repeat_penalty = 1.10f;
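
Because the new field defaults to 0.0f in the prompt context above, callers that never set min_p get exactly the previous sampling behavior; only callers that opt in see candidates filtered. A hedged usage sketch follows, using a stand-in struct that mirrors just the defaults visible in this hunk (the real prompt context has more members):

    // Stand-in for illustration only; field names and defaults copied from the hunk above.
    #include <cstdint>

    struct PromptContextSketch {
        int32_t n_predict      = 200;
        int32_t top_k          = 40;
        float   top_p          = 0.9f;
        float   min_p          = 0.0f;   // 0.0f = min-p filtering disabled
        float   temp           = 0.9f;
        int32_t n_batch        = 9;
        float   repeat_penalty = 1.10f;
    };

    int main() {
        PromptContextSketch ctx;   // defaults: min-p off, same output as before this change
        ctx.min_p = 0.05f;         // opt in: drop candidates below 5% of the top token's probability
    }
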
@@ -134,6 +134,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
     wrapper->promptContext.n_predict = ctx->n_predict;
     wrapper->promptContext.top_k = ctx->top_k;
     wrapper->promptContext.top_p = ctx->top_p;
+    wrapper->promptContext.min_p = ctx->min_p;
     wrapper->promptContext.temp = ctx->temp;
     wrapper->promptContext.n_batch = ctx->n_batch;
     wrapper->promptContext.repeat_penalty = ctx->repeat_penalty;

@@ -156,6 +157,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
     ctx->n_predict = wrapper->promptContext.n_predict;
     ctx->top_k = wrapper->promptContext.top_k;
     ctx->top_p = wrapper->promptContext.top_p;
+    ctx->min_p = wrapper->promptContext.min_p;
     ctx->temp = wrapper->promptContext.temp;
     ctx->n_batch = wrapper->promptContext.n_batch;
     ctx->repeat_penalty = wrapper->promptContext.repeat_penalty;

@@ -39,6 +39,7 @@ struct llmodel_prompt_context {
     int32_t n_predict;    // number of tokens to predict
     int32_t top_k;        // top k logits to sample from
     float top_p;          // nucleus sampling probability threshold
+    float min_p;          // Min P sampling
     float temp;           // temperature to adjust model's output distribution
     int32_t n_batch;      // number of predictions to generate in parallel
     float repeat_penalty; // penalty factor for repeated tokens
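
For consumers of the C API, the new field slots into llmodel_prompt_context alongside the existing sampling knobs. A hedged sketch of populating it, assuming llmodel_c.h is on the include path; only the fields visible in this hunk are set, and the real struct has further members (context size, repeat window, and so on) that a caller must also fill before prompting:

    #include "llmodel_c.h"

    int main() {
        llmodel_prompt_context ctx = {};   // zero-init the fields not shown in this hunk
        ctx.n_predict      = 128;
        ctx.top_k          = 40;
        ctx.top_p          = 0.9f;
        ctx.min_p          = 0.05f;        // new: drop candidates below 5% of the top token's probability
        ctx.temp           = 0.7f;
        ctx.n_batch        = 8;
        ctx.repeat_penalty = 1.1f;
        // ... pass &ctx to llmodel_prompt(...) together with a model and callbacks
    }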