chat: faster KV shift, continue generating, fix stop sequences (#2781)

* Don't stop generating at end of context
* Use llama_kv_cache ops to shift context
* Fix and improve reverse prompt detection
* Replace prompt recalc callback with a flag to disallow context shift
This commit is contained in:
Jared Van Bortel
2024-08-07 11:25:24 -04:00
committed by GitHub
parent 90de2d32f8
commit be66ec8ab5
16 changed files with 285 additions and 230 deletions

View File

@@ -102,7 +102,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
: QObject{nullptr}
, m_promptResponseTokens(0)
, m_promptTokens(0)
, m_isRecalc(false)
, m_restoringFromText(false)
, m_shouldBeLoaded(false)
, m_forceUnloadModel(false)
, m_markedForDeletion(false)
@@ -712,17 +712,6 @@ bool ChatLLM::handleResponse(int32_t token, const std::string &response)
return !m_stopGenerating;
}
bool ChatLLM::handleRecalculate(bool isRecalc)
{
#if defined(DEBUG)
qDebug() << "recalculate" << m_llmThread.objectName() << isRecalc;
#endif
if (m_isRecalc != isRecalc) {
m_isRecalc = isRecalc;
emit recalcChanged();
}
return !m_stopGenerating;
}
bool ChatLLM::prompt(const QList<QString> &collectionList, const QString &prompt)
{
if (m_restoreStateFromText) {
@@ -776,7 +765,6 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
auto promptFunc = std::bind(&ChatLLM::handlePrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&ChatLLM::handleResponse, this, std::placeholders::_1,
std::placeholders::_2);
auto recalcFunc = std::bind(&ChatLLM::handleRecalculate, this, std::placeholders::_1);
emit promptProcessing();
m_ctx.n_predict = n_predict;
m_ctx.top_k = top_k;
@@ -796,10 +784,12 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
m_timer->start();
if (!docsContext.isEmpty()) {
auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode localdocs context without a response
m_llModelInfo.model->prompt(docsContext.toStdString(), "%1", promptFunc, responseFunc, recalcFunc, m_ctx);
m_llModelInfo.model->prompt(docsContext.toStdString(), "%1", promptFunc, responseFunc,
/*allowContextShift*/ true, m_ctx);
m_ctx.n_predict = old_n_predict; // now we are ready for a response
}
m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc,
/*allowContextShift*/ true, m_ctx);
#if defined(DEBUG)
printf("\n");
fflush(stdout);
@@ -904,10 +894,9 @@ void ChatLLM::generateName()
auto promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
auto promptFunc = std::bind(&ChatLLM::handleNamePrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&ChatLLM::handleNameResponse, this, std::placeholders::_1, std::placeholders::_2);
auto recalcFunc = std::bind(&ChatLLM::handleNameRecalculate, this, std::placeholders::_1);
LLModel::PromptContext ctx = m_ctx;
m_llModelInfo.model->prompt(chatNamePrompt.toStdString(), promptTemplate.toStdString(),
promptFunc, responseFunc, recalcFunc, ctx);
promptFunc, responseFunc, /*allowContextShift*/ false, ctx);
std::string trimmed = trim_whitespace(m_nameResponse);
if (trimmed != m_nameResponse) {
m_nameResponse = trimmed;
@@ -944,15 +933,6 @@ bool ChatLLM::handleNameResponse(int32_t token, const std::string &response)
return words.size() <= 3;
}
bool ChatLLM::handleNameRecalculate(bool isRecalc)
{
#if defined(DEBUG)
qDebug() << "name recalc" << m_llmThread.objectName() << isRecalc;
#endif
Q_UNUSED(isRecalc);
return true;
}
bool ChatLLM::handleQuestionPrompt(int32_t token)
{
#if defined(DEBUG)
@@ -991,15 +971,6 @@ bool ChatLLM::handleQuestionResponse(int32_t token, const std::string &response)
return true;
}
bool ChatLLM::handleQuestionRecalculate(bool isRecalc)
{
#if defined(DEBUG)
qDebug() << "name recalc" << m_llmThread.objectName() << isRecalc;
#endif
Q_UNUSED(isRecalc);
return true;
}
void ChatLLM::generateQuestions(qint64 elapsed)
{
Q_ASSERT(isModelLoaded());
@@ -1019,12 +990,11 @@ void ChatLLM::generateQuestions(qint64 elapsed)
auto promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
auto promptFunc = std::bind(&ChatLLM::handleQuestionPrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&ChatLLM::handleQuestionResponse, this, std::placeholders::_1, std::placeholders::_2);
auto recalcFunc = std::bind(&ChatLLM::handleQuestionRecalculate, this, std::placeholders::_1);
LLModel::PromptContext ctx = m_ctx;
QElapsedTimer totalTime;
totalTime.start();
m_llModelInfo.model->prompt(suggestedFollowUpPrompt,
promptTemplate.toStdString(), promptFunc, responseFunc, recalcFunc, ctx);
m_llModelInfo.model->prompt(suggestedFollowUpPrompt, promptTemplate.toStdString(), promptFunc, responseFunc,
/*allowContextShift*/ false, ctx);
elapsed += totalTime.elapsed();
emit responseStopped(elapsed);
}
@@ -1039,15 +1009,6 @@ bool ChatLLM::handleSystemPrompt(int32_t token)
return !m_stopGenerating;
}
bool ChatLLM::handleSystemRecalculate(bool isRecalc)
{
#if defined(DEBUG)
qDebug() << "system recalc" << m_llmThread.objectName() << isRecalc;
#endif
Q_UNUSED(isRecalc);
return false;
}
bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token)
{
#if defined(DEBUG)
@@ -1057,15 +1018,6 @@ bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token)
return !m_stopGenerating;
}
bool ChatLLM::handleRestoreStateFromTextRecalculate(bool isRecalc)
{
#if defined(DEBUG)
qDebug() << "restore state from text recalc" << m_llmThread.objectName() << isRecalc;
#endif
Q_UNUSED(isRecalc);
return false;
}
// this function serialized the cached model state to disk.
// we want to also serialize n_ctx, and read it at load time.
bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
@@ -1268,7 +1220,6 @@ void ChatLLM::processSystemPrompt()
m_ctx = LLModel::PromptContext();
auto promptFunc = std::bind(&ChatLLM::handleSystemPrompt, this, std::placeholders::_1);
auto recalcFunc = std::bind(&ChatLLM::handleSystemRecalculate, this, std::placeholders::_1);
const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
@@ -1294,7 +1245,7 @@ void ChatLLM::processSystemPrompt()
#endif
auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode system prompt without a response
// use "%1%2" and not "%1" to avoid implicit whitespace
m_llModelInfo.model->prompt(systemPrompt, "%1%2", promptFunc, nullptr, recalcFunc, m_ctx, true);
m_llModelInfo.model->prompt(systemPrompt, "%1%2", promptFunc, nullptr, /*allowContextShift*/ true, m_ctx, true);
m_ctx.n_predict = old_n_predict;
#if defined(DEBUG)
printf("\n");
@@ -1311,14 +1262,13 @@ void ChatLLM::processRestoreStateFromText()
if (!isModelLoaded() || !m_restoreStateFromText || m_isServer)
return;
m_isRecalc = true;
emit recalcChanged();
m_restoringFromText = true;
emit restoringFromTextChanged();
m_stopGenerating = false;
m_ctx = LLModel::PromptContext();
auto promptFunc = std::bind(&ChatLLM::handleRestoreStateFromTextPrompt, this, std::placeholders::_1);
auto recalcFunc = std::bind(&ChatLLM::handleRestoreStateFromTextRecalculate, this, std::placeholders::_1);
const QString promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
@@ -1351,7 +1301,7 @@ void ChatLLM::processRestoreStateFromText()
auto responseText = response.second.toStdString();
m_llModelInfo.model->prompt(prompt.second.toStdString(), promptTemplate.toStdString(), promptFunc, nullptr,
recalcFunc, m_ctx, false, &responseText);
/*allowContextShift*/ true, m_ctx, false, &responseText);
}
if (!m_stopGenerating) {
@@ -1359,8 +1309,8 @@ void ChatLLM::processRestoreStateFromText()
m_stateFromText.clear();
}
m_isRecalc = false;
emit recalcChanged();
m_restoringFromText = false;
emit restoringFromTextChanged();
m_pristineLoadedState = false;
}