backend: dedupe tokenizing code in mpt/gptj

This commit is contained in:
Aaron Miller
2023-05-15 17:42:20 -07:00
committed by AT
parent 6182026c70
commit d14936bfd6
4 changed files with 6 additions and 102 deletions

View File

@@ -219,6 +219,7 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
gpt_vocab::id gpt_sample_top_k_top_p(
const gpt_vocab & vocab,
const size_t actualVocabSize,
const int32_t * last_n_tokens_data,
int last_n_tokens_size,
const std::vector<float> logits,
@@ -227,7 +228,7 @@ gpt_vocab::id gpt_sample_top_k_top_p(
double temp,
float repeat_penalty,
std::mt19937 & rng) {
int n_logits = vocab.id_to_token.size();
int n_logits = actualVocabSize;
const auto last_n_tokens = std::vector<int32_t>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size);
const auto * plogits = logits.data() + logits.size() - n_logits;
@@ -312,4 +313,4 @@ gpt_vocab::id gpt_sample_top_k_top_p(
int idx = dist(rng);
return logits_id[idx].second;
}
}