mirror of https://github.com/nomic-ai/gpt4all.git
chat: faster KV shift, continue generating, fix stop sequences (#2781)
* Don't stop generating at end of context
* Use llama_kv_cache ops to shift context
* Fix and improve reverse prompt detection
* Replace prompt recalc callback with a flag to disallow context shift
@@ -531,10 +531,7 @@ size_t LLamaModel::restoreState(const uint8_t *src)
 std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::string &str, bool special)
 {
     bool atStart = m_tokenize_last_token == -1;
-    bool insertSpace = atStart || (
-        llama_token_get_attr(d_ptr->model, m_tokenize_last_token)
-            & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)
-    );
+    bool insertSpace = atStart || isSpecialToken(m_tokenize_last_token);
     std::vector<LLModel::Token> fres(str.length() + 4);
     int32_t fres_len = llama_tokenize_gpt4all(
         d_ptr->model, str.c_str(), str.length(), fres.data(), fres.size(), /*add_special*/ atStart,
@@ -546,6 +543,12 @@ std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::
     return fres;
 }
 
+bool LLamaModel::isSpecialToken(Token id) const
+{
+    return llama_token_get_attr(d_ptr->model, id)
+        & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN);
+}
+
 std::string LLamaModel::tokenToString(Token id) const
 {
     std::vector<char> result(8, 0);
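
The new isSpecialToken() helper factors out the attribute test that tokenize() previously did inline: a token is treated as special when llama.cpp marks it as a control, user-defined, or unknown token, and insertSpace is computed from the same check as before, now reusable outside tokenize(). Below is a minimal standalone sketch of that test against the plain llama.cpp C API; the function name is_special_token and the variables model/tok are illustrative, not part of this diff.

#include <llama.h>  // llama.cpp C API

// Sketch only: the same "is this token special?" test that
// LLamaModel::isSpecialToken() performs, written against plain llama.cpp.
static bool is_special_token(const llama_model *model, llama_token tok)
{
    // Control tokens (e.g. BOS/EOS), user-defined tokens (e.g. chat-template
    // markers), and unknown tokens all count as special here.
    return llama_token_get_attr(model, tok)
        & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN);
}

In tokenize(), this is what gates insertSpace: it is true only at the very start of generation (m_tokenize_last_token == -1) or immediately after such a special token.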
@@ -595,6 +598,30 @@ bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<int32_t> &toke
     return res == 0;
 }
 
+void LLamaModel::shiftContext(PromptContext &promptCtx)
+{
+    // infinite text generation via context shifting
+
+    // erase up to n_ctx*contextErase tokens
+    int n_keep = shouldAddBOS();
+    int n_past = promptCtx.n_past;
+    int n_discard = std::min(n_past - n_keep, int(promptCtx.n_ctx * promptCtx.contextErase));
+
+    assert(n_discard > 0);
+    if (n_discard <= 0)
+        return;
+
+    std::cerr << "Llama: context full, swapping: n_past = " << n_past << ", n_keep = " << n_keep
+              << ", n_discard = " << n_discard << "\n";
+
+    // erase the first n_discard tokens from the context
+    llama_kv_cache_seq_rm (d_ptr->ctx, 0, n_keep, n_keep + n_discard);
+    llama_kv_cache_seq_add(d_ptr->ctx, 0, n_keep + n_discard, n_past, -n_discard);
+
+    promptCtx.tokens.erase(promptCtx.tokens.begin() + n_keep, promptCtx.tokens.begin() + n_keep + n_discard);
+    promptCtx.n_past = promptCtx.tokens.size();
+}
+
 int32_t LLamaModel::contextLength() const
 {
     return llama_n_ctx(d_ptr->ctx);
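
shiftContext() is the "faster KV shift" from the commit title: when the context fills up, instead of re-evaluating the prompt (the old prompt-recalc path), it discards the oldest n_discard tokens after the first n_keep (the BOS token, if the model uses one) and slides the remaining KV-cache entries down by n_discard positions via llama_kv_cache_seq_rm and llama_kv_cache_seq_add, so the kept tokens are repositioned rather than recomputed. A rough worked example of the arithmetic, using assumed values (n_ctx = 2048, contextErase = 0.75, n_keep = 1, full context so n_past = 2048) rather than anything taken from the diff:

#include <algorithm>
#include <cstdio>

int main()
{
    // Assumed illustration values; contextErase is the fraction of the
    // context that shiftContext() is allowed to reclaim.
    int   n_ctx        = 2048;
    float contextErase = 0.75f;
    int   n_keep       = 1;    // shouldAddBOS() -> keep the BOS token
    int   n_past       = 2048; // context is full

    int n_discard = std::min(n_past - n_keep, int(n_ctx * contextErase));
    // n_discard = min(2047, 1536) = 1536

    // KV-cache positions [n_keep, n_keep + n_discard) = [1, 1537) are removed,
    // positions [1537, 2048) are shifted down by 1536, and what survives is
    // BOS plus the 511 most recent tokens: new n_past = 2048 - 1536 = 512.
    std::printf("n_discard = %d, new n_past = %d\n", n_discard, n_past - n_discard);
    return 0;
}

Because each shift frees a large fraction of the context while keeping the most recent window (plus BOS), generation can continue indefinitely without re-evaluating the kept tokens.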