backend: fix extra spaces in tokenization and a CUDA crash (#2778)

Also potentially improves accuracy of BOS insertion, token cache, and logit indexing.

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
Jared Van Bortel
2024-08-01 10:46:36 -04:00
committed by GitHub
parent da59c9f5ea
commit 51bd01ae05
10 changed files with 46 additions and 36 deletions

View File

@@ -145,9 +145,8 @@ static int llama_sample_top_p_top_k(
float top_p,
float min_p,
float temp,
float repeat_penalty,
int32_t pos) {
auto logits = llama_get_logits_ith(ctx, pos);
float repeat_penalty) {
auto logits = llama_get_logits_ith(ctx, -1);
auto n_vocab = llama_n_vocab(llama_get_model(ctx));
// Populate initial list of all candidates
std::vector<llama_token_data> candidates;
@@ -529,13 +528,21 @@ size_t LLamaModel::restoreState(const uint8_t *src)
return llama_set_state_data(d_ptr->ctx, const_cast<uint8_t*>(src));
}
std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::string &str, bool special) const
std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::string &str, bool special)
{
const bool wantBOS = ctx.n_past == 0 && ctx.tokens.empty();
const bool useBOS = wantBOS && shouldAddBOS();
bool atStart = m_tokenize_last_token == -1;
bool insertSpace = atStart || (
llama_token_get_attr(d_ptr->model, m_tokenize_last_token)
& (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)
);
std::vector<LLModel::Token> fres(str.length() + 4);
auto fres_len = llama_tokenize(d_ptr->model, str.c_str(), str.length(), fres.data(), fres.size(), useBOS, special);
int32_t fres_len = llama_tokenize_gpt4all(
d_ptr->model, str.c_str(), str.length(), fres.data(), fres.size(), /*add_special*/ atStart,
/*parse_special*/ special, /*insert_space*/ insertSpace
);
fres.resize(fres_len);
if (fres_len)
m_tokenize_last_token = fres.back();
return fres;
}
@@ -561,7 +568,7 @@ LLModel::Token LLamaModel::sampleToken(PromptContext &promptCtx) const
return llama_sample_top_p_top_k(d_ptr->ctx,
promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks,
n_prev_toks, promptCtx.top_k, promptCtx.top_p, promptCtx.min_p, promptCtx.temp,
promptCtx.repeat_penalty, promptCtx.n_last_batch_tokens - 1);
promptCtx.repeat_penalty);
}
bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const
@@ -571,7 +578,6 @@ bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<int32_t> &toke
llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
batch.n_tokens = tokens.size();
ctx.n_last_batch_tokens = tokens.size();
for (int32_t i = 0; i < batch.n_tokens; i++) {
batch.token [i] = tokens[i];
@@ -601,10 +607,7 @@ const std::vector<LLModel::Token> &LLamaModel::endTokens() const
bool LLamaModel::shouldAddBOS() const
{
int add_bos = llama_add_bos_token(d_ptr->model);
if (add_bos != -1) { return add_bos; }
auto vocab_type = llama_vocab_type(d_ptr->model);
return vocab_type == LLAMA_VOCAB_TYPE_SPM || vocab_type == LLAMA_VOCAB_TYPE_WPM;
return llama_add_bos_token(d_ptr->model);
}
int32_t LLamaModel::maxContextLength(std::string const &modelPath) const
@@ -946,7 +949,7 @@ void LLamaModel::embedInternal(
const llama_token bos_token = llama_token_bos(d_ptr->model);
const llama_token eos_token = llama_token_eos(d_ptr->model);
bool useBOS = shouldAddBOS();
bool useBOS = llama_add_bos_token(d_ptr->model);
bool useEOS = llama_vocab_type(d_ptr->model) == LLAMA_VOCAB_TYPE_WPM;
// no EOS, optional BOS
@@ -954,13 +957,16 @@ void LLamaModel::embedInternal(
if (!text.empty() && text[0] != ' ') {
text = ' ' + text; // normalize for SPM - our fork of llama.cpp doesn't add a space prefix
}
wantBOS &= useBOS;
tokens.resize(text.length()+4);
int32_t n_tokens = llama_tokenize(d_ptr->model, text.c_str(), text.length(), tokens.data(), tokens.size(), wantBOS, false);
int32_t n_tokens = llama_tokenize_gpt4all(
d_ptr->model, text.c_str(), text.length(), tokens.data(), tokens.size(), /*add_special*/ wantBOS,
/*parse_special*/ false, /*insert_space*/ false
);
if (n_tokens) {
(void)eos_token;
assert((useEOS && wantBOS) == (eos_token != -1 && tokens[n_tokens - 1] == eos_token));
(void)useBOS;
assert((useEOS && wantBOS && useBOS) == (eos_token != -1 && tokens[n_tokens - 1] == eos_token));
if (useEOS && wantBOS)
n_tokens--; // erase EOS/SEP
}