chat: faster KV shift, continue generating, fix stop sequences (#2781)

* Don't stop generating at end of context
* Use llama_kv_cache ops to shift context
* Fix and improve reverse prompt detection
* Replace prompt recalc callback with a flag to disallow context shift
This commit is contained in:
Jared Van Bortel
2024-08-07 11:25:24 -04:00
committed by GitHub
parent 90de2d32f8
commit be66ec8ab5
16 changed files with 285 additions and 230 deletions

View File

@@ -531,10 +531,7 @@ size_t LLamaModel::restoreState(const uint8_t *src)
std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::string &str, bool special)
{
bool atStart = m_tokenize_last_token == -1;
bool insertSpace = atStart || (
llama_token_get_attr(d_ptr->model, m_tokenize_last_token)
& (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)
);
bool insertSpace = atStart || isSpecialToken(m_tokenize_last_token);
std::vector<LLModel::Token> fres(str.length() + 4);
int32_t fres_len = llama_tokenize_gpt4all(
d_ptr->model, str.c_str(), str.length(), fres.data(), fres.size(), /*add_special*/ atStart,
@@ -546,6 +543,12 @@ std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::
return fres;
}
bool LLamaModel::isSpecialToken(Token id) const
{
return llama_token_get_attr(d_ptr->model, id)
& (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN);
}
std::string LLamaModel::tokenToString(Token id) const
{
std::vector<char> result(8, 0);
@@ -595,6 +598,30 @@ bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<int32_t> &toke
return res == 0;
}
void LLamaModel::shiftContext(PromptContext &promptCtx)
{
// infinite text generation via context shifting
// erase up to n_ctx*contextErase tokens
int n_keep = shouldAddBOS();
int n_past = promptCtx.n_past;
int n_discard = std::min(n_past - n_keep, int(promptCtx.n_ctx * promptCtx.contextErase));
assert(n_discard > 0);
if (n_discard <= 0)
return;
std::cerr << "Llama: context full, swapping: n_past = " << n_past << ", n_keep = " << n_keep
<< ", n_discard = " << n_discard << "\n";
// erase the first n_discard tokens from the context
llama_kv_cache_seq_rm (d_ptr->ctx, 0, n_keep, n_keep + n_discard);
llama_kv_cache_seq_add(d_ptr->ctx, 0, n_keep + n_discard, n_past, -n_discard);
promptCtx.tokens.erase(promptCtx.tokens.begin() + n_keep, promptCtx.tokens.begin() + n_keep + n_discard);
promptCtx.n_past = promptCtx.tokens.size();
}
int32_t LLamaModel::contextLength() const
{
return llama_n_ctx(d_ptr->ctx);