llamamodel: use greedy sampling when temp=0 (#2854)

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Author: Jared Van Bortel <jared@nomic.ai> (committed via GitHub)
Date:   2024-08-13 17:04:50 -04:00
Commit: 6518b33697, parent: 8ccf1fa2f5
3 changed files with 27 additions and 10 deletions

@@ -137,7 +137,7 @@ struct gpt_params {
     bool use_mlock = false; // use mlock to keep model in memory
 };
 
-static int llama_sample_top_p_top_k(
+static llama_token llama_sample_top_p_top_k(
         llama_context *ctx,
         const llama_token *last_n_tokens_data,
         int last_n_tokens_size,
@@ -157,14 +157,22 @@ static int llama_sample_top_p_top_k(
     llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
     // Sample repeat penalty
     llama_sample_repetition_penalties(nullptr, &candidates_p, last_n_tokens_data, last_n_tokens_size, repeat_penalty, 0.0f, 0.0f);
-    // Temperature sampling
-    llama_sample_top_k(ctx, &candidates_p, top_k, 1);
-    llama_sample_tail_free(ctx, &candidates_p, 1.0f, 1);
-    llama_sample_typical(ctx, &candidates_p, 1.0f, 1);
-    llama_sample_top_p(ctx, &candidates_p, top_p, 1);
-    llama_sample_min_p(ctx, &candidates_p, min_p, 1);
-    llama_sample_temp(ctx, &candidates_p, temp);
-    return llama_sample_token(ctx, &candidates_p);
+
+    llama_token id;
+    if (temp == 0.0) {
+        // greedy sampling, no probs
+        id = llama_sample_token_greedy(ctx, &candidates_p);
+    } else {
+        // temperature sampling
+        llama_sample_top_k(ctx, &candidates_p, top_k, 1);
+        llama_sample_tail_free(ctx, &candidates_p, 1.0f, 1);
+        llama_sample_typical(ctx, &candidates_p, 1.0f, 1);
+        llama_sample_top_p(ctx, &candidates_p, top_p, 1);
+        llama_sample_min_p(ctx, &candidates_p, min_p, 1);
+        llama_sample_temp(ctx, &candidates_p, temp);
+        id = llama_sample_token(ctx, &candidates_p);
+    }
+    return id;
 }
 
 const char *get_arch_name(gguf_context *ctx_gguf)
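
Why the special case: llama.cpp's llama_sample_temp scales each candidate logit by 1/temp, so temp == 0 would divide by zero rather than mean "fully deterministic". Greedy decoding is the natural limit of temperature sampling as temp approaches 0, and it needs no probability computation at all, which is why the new branch calls llama_sample_token_greedy directly and skips the top-k/tail-free/typical/top-p/min-p filters entirely. The sketch below illustrates the same branch structure outside the llama.cpp API; sample_token and its signature are hypothetical, invented for this example.

// Illustrative sketch of the temp == 0 special case; not llama.cpp code.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <random>
#include <vector>

// Pick the next token id from raw logits. temp == 0 selects the single
// most likely token (greedy / argmax); temp > 0 samples from the
// temperature-scaled softmax distribution.
static std::size_t sample_token(const std::vector<float> &logits, float temp, std::mt19937 &rng)
{
    if (temp == 0.0f) {
        // greedy sampling, no probs: just take the argmax
        return std::max_element(logits.begin(), logits.end()) - logits.begin();
    }
    // temperature sampling: softmax over logits / temp
    float max_logit = *std::max_element(logits.begin(), logits.end());
    std::vector<double> weights(logits.size());
    for (std::size_t i = 0; i < logits.size(); i++)
        weights[i] = std::exp(double(logits[i] - max_logit) / temp); // subtract max for numerical stability
    std::discrete_distribution<std::size_t> dist(weights.begin(), weights.end()); // normalizes the weights
    return dist(rng);
}

As temp shrinks toward 0, the softmax mass concentrates on the largest logit, so the greedy branch returns the same token the limit of temperature sampling would, just without constructing a distribution.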