add min_p sampling parameter (#2014)

Signed-off-by: Christopher Barrera <cb@arda.tx.rr.com>
Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>
Author: chrisbarrera
Date: 2024-02-24 16:51:34 -06:00
Commit: f8b1069a1c (parent a153cc5b25)
28 changed files with 176 additions and 14 deletions

@@ -64,6 +64,7 @@ static int llama_sample_top_p_top_k(
     int last_n_tokens_size,
     int top_k,
     float top_p,
+    float min_p,
     float temp,
     float repeat_penalty,
     int32_t pos) {
@@ -83,6 +84,7 @@ static int llama_sample_top_p_top_k(
     llama_sample_tail_free(ctx, &candidates_p, 1.0f, 1);
     llama_sample_typical(ctx, &candidates_p, 1.0f, 1);
     llama_sample_top_p(ctx, &candidates_p, top_p, 1);
+    llama_sample_min_p(ctx, &candidates_p, min_p, 1);
     llama_sample_temp(ctx, &candidates_p, temp);
     return llama_sample_token(ctx, &candidates_p);
 }
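
For reference, min-p sampling (the technique behind llama_sample_min_p in llama.cpp) keeps only the candidate tokens whose probability is at least min_p times the probability of the single most likely token, so the cutoff scales with the model's confidence at each step. The standalone sketch below illustrates that rule under the stated assumption; Candidate and min_p_filter are invented names for illustration, not the llama.cpp API.

// Minimal sketch of the min-p rule: keep token i iff p_i >= min_p * p_max.
// "Candidate" and "min_p_filter" are illustrative names, not llama.cpp API.
#include <algorithm>
#include <cstdio>
#include <vector>

struct Candidate {
    int   id;
    float p;  // normalized probability of this token
};

static void min_p_filter(std::vector<Candidate> &cands, float min_p) {
    if (min_p <= 0.0f || cands.empty())
        return;  // min_p == 0 disables the filter entirely

    // Find the probability of the most likely token.
    const float p_max = std::max_element(cands.begin(), cands.end(),
        [](const Candidate &a, const Candidate &b) { return a.p < b.p; })->p;

    // Drop every candidate below the scaled cutoff.
    const float cutoff = min_p * p_max;
    cands.erase(std::remove_if(cands.begin(), cands.end(),
                    [cutoff](const Candidate &c) { return c.p < cutoff; }),
                cands.end());
}

int main() {
    std::vector<Candidate> cands = {{0, 0.50f}, {1, 0.30f}, {2, 0.15f}, {3, 0.05f}};
    min_p_filter(cands, 0.2f);  // cutoff = 0.2 * 0.50 = 0.10, so token 3 is dropped
    for (const auto &c : cands)
        std::printf("token %d p=%.2f\n", c.id, c.p);
}

Note that in the hunk above the filter slots in after top-p and before the temperature step, matching the ordering of the other truncation samplers in the chain.
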
@@ -392,7 +394,7 @@ LLModel::Token LLamaModel::sampleToken(PromptContext &promptCtx) const
     const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, promptCtx.tokens.size());
     return llama_sample_top_p_top_k(d_ptr->ctx,
         promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks,
-        n_prev_toks, promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
+        n_prev_toks, promptCtx.top_k, promptCtx.top_p, promptCtx.min_p, promptCtx.temp,
         promptCtx.repeat_penalty, promptCtx.n_last_batch_tokens - 1);
 }
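
Presumably, PromptContext gains a matching min_p field elsewhere in this commit's 28 changed files, since it is read as promptCtx.min_p above. A value of 0.0f would keep every candidate and thus preserve the previous top-k/top-p behavior, while a small positive value such as 0.05f (an illustrative setting, not a documented default) discards the low-probability tail.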