Mirror of https://github.com/nomic-ai/gpt4all.git
chat: faster KV shift, continue generating, fix stop sequences (#2781)
* Don't stop generating at end of context
* Use llama_kv_cache ops to shift context
* Fix and improve reverse prompt detection
* Replace prompt recalc callback with a flag to disallow context shift
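Only the last bullet is visible in this file: every call into LLModel::prompt() drops the recalculate callback and instead passes a boolean saying whether the backend may shift the KV cache when the context window fills. The declaration itself lives in the llmodel interface and is not part of this diff; the sketch below is reconstructed from the call sites in the hunks that follow, so parameter names and defaults are assumptions rather than a copy of the real header.

    // Sketch only: inferred from the call sites below, not copied from llmodel.h.
    virtual void prompt(const std::string &prompt,
                        const std::string &promptTemplate,
                        std::function<bool(int32_t)> promptCallback,
                        std::function<bool(int32_t, const std::string &)> responseCallback,
                        bool allowContextShift,            // replaces the old recalculate callback
                        PromptContext &promptCtx,
                        bool special = false,              // extra bool seen in the system-prompt call below
                        std::string *fakeReply = nullptr); // seen in processRestoreStateFromText below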
@@ -102,7 +102,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
     : QObject{nullptr}
     , m_promptResponseTokens(0)
     , m_promptTokens(0)
-    , m_isRecalc(false)
+    , m_restoringFromText(false)
     , m_shouldBeLoaded(false)
     , m_forceUnloadModel(false)
     , m_markedForDeletion(false)
@@ -712,17 +712,6 @@ bool ChatLLM::handleResponse(int32_t token, const std::string &response)
     return !m_stopGenerating;
 }
 
-bool ChatLLM::handleRecalculate(bool isRecalc)
-{
-#if defined(DEBUG)
-    qDebug() << "recalculate" << m_llmThread.objectName() << isRecalc;
-#endif
-    if (m_isRecalc != isRecalc) {
-        m_isRecalc = isRecalc;
-        emit recalcChanged();
-    }
-    return !m_stopGenerating;
-}
 bool ChatLLM::prompt(const QList<QString> &collectionList, const QString &prompt)
 {
     if (m_restoreStateFromText) {
@@ -776,7 +765,6 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
     auto promptFunc = std::bind(&ChatLLM::handlePrompt, this, std::placeholders::_1);
     auto responseFunc = std::bind(&ChatLLM::handleResponse, this, std::placeholders::_1,
         std::placeholders::_2);
-    auto recalcFunc = std::bind(&ChatLLM::handleRecalculate, this, std::placeholders::_1);
     emit promptProcessing();
     m_ctx.n_predict = n_predict;
     m_ctx.top_k = top_k;
@@ -796,10 +784,12 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
     m_timer->start();
     if (!docsContext.isEmpty()) {
         auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode localdocs context without a response
-        m_llModelInfo.model->prompt(docsContext.toStdString(), "%1", promptFunc, responseFunc, recalcFunc, m_ctx);
+        m_llModelInfo.model->prompt(docsContext.toStdString(), "%1", promptFunc, responseFunc,
+                                    /*allowContextShift*/ true, m_ctx);
         m_ctx.n_predict = old_n_predict; // now we are ready for a response
     }
-    m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
+    m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc,
+                                /*allowContextShift*/ true, m_ctx);
 #if defined(DEBUG)
     printf("\n");
     fflush(stdout);
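Both calls above now pass /*allowContextShift*/ true, so the backend may make room and keep generating when n_ctx is exhausted instead of recalculating the whole context. The backend side of that (the "llama_kv_cache ops" bullet) lives in llamamodel.cpp and is not shown in this diff; the following is only a minimal sketch of such a shift using llama.cpp's KV-cache sequence ops, assuming a single sequence id 0 and illustrative names.

    // Illustrative sketch, not the gpt4all implementation: drop the oldest tokens
    // after the first n_keep and slide the remaining cache entries left so that
    // decoding can continue in the freed space.
    #include <llama.h>

    static void shift_context(llama_context *lctx, int n_keep, int n_past, int n_discard)
    {
        // erase cache entries for positions [n_keep, n_keep + n_discard)
        llama_kv_cache_seq_rm (lctx, /*seq_id*/ 0, n_keep, n_keep + n_discard);
        // move positions [n_keep + n_discard, n_past) back by n_discard
        llama_kv_cache_seq_add(lctx, /*seq_id*/ 0, n_keep + n_discard, n_past, -n_discard);
        // the caller must also forget the same tokens in its own history and
        // continue decoding at n_past - n_discard
    }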
@@ -904,10 +894,9 @@ void ChatLLM::generateName()
     auto promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
     auto promptFunc = std::bind(&ChatLLM::handleNamePrompt, this, std::placeholders::_1);
     auto responseFunc = std::bind(&ChatLLM::handleNameResponse, this, std::placeholders::_1, std::placeholders::_2);
-    auto recalcFunc = std::bind(&ChatLLM::handleNameRecalculate, this, std::placeholders::_1);
     LLModel::PromptContext ctx = m_ctx;
     m_llModelInfo.model->prompt(chatNamePrompt.toStdString(), promptTemplate.toStdString(),
-                                promptFunc, responseFunc, recalcFunc, ctx);
+                                promptFunc, responseFunc, /*allowContextShift*/ false, ctx);
     std::string trimmed = trim_whitespace(m_nameResponse);
     if (trimmed != m_nameResponse) {
         m_nameResponse = trimmed;
@@ -944,15 +933,6 @@ bool ChatLLM::handleNameResponse(int32_t token, const std::string &response)
     return words.size() <= 3;
 }
 
-bool ChatLLM::handleNameRecalculate(bool isRecalc)
-{
-#if defined(DEBUG)
-    qDebug() << "name recalc" << m_llmThread.objectName() << isRecalc;
-#endif
-    Q_UNUSED(isRecalc);
-    return true;
-}
-
 bool ChatLLM::handleQuestionPrompt(int32_t token)
 {
 #if defined(DEBUG)
@@ -991,15 +971,6 @@ bool ChatLLM::handleQuestionResponse(int32_t token, const std::string &response)
     return true;
 }
 
-bool ChatLLM::handleQuestionRecalculate(bool isRecalc)
-{
-#if defined(DEBUG)
-    qDebug() << "name recalc" << m_llmThread.objectName() << isRecalc;
-#endif
-    Q_UNUSED(isRecalc);
-    return true;
-}
-
 void ChatLLM::generateQuestions(qint64 elapsed)
 {
     Q_ASSERT(isModelLoaded());
@@ -1019,12 +990,11 @@ void ChatLLM::generateQuestions(qint64 elapsed)
     auto promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
     auto promptFunc = std::bind(&ChatLLM::handleQuestionPrompt, this, std::placeholders::_1);
     auto responseFunc = std::bind(&ChatLLM::handleQuestionResponse, this, std::placeholders::_1, std::placeholders::_2);
-    auto recalcFunc = std::bind(&ChatLLM::handleQuestionRecalculate, this, std::placeholders::_1);
     LLModel::PromptContext ctx = m_ctx;
     QElapsedTimer totalTime;
     totalTime.start();
-    m_llModelInfo.model->prompt(suggestedFollowUpPrompt,
-                                promptTemplate.toStdString(), promptFunc, responseFunc, recalcFunc, ctx);
+    m_llModelInfo.model->prompt(suggestedFollowUpPrompt, promptTemplate.toStdString(), promptFunc, responseFunc,
+                                /*allowContextShift*/ false, ctx);
     elapsed += totalTime.elapsed();
     emit responseStopped(elapsed);
 }
@@ -1039,15 +1009,6 @@ bool ChatLLM::handleSystemPrompt(int32_t token)
     return !m_stopGenerating;
 }
 
-bool ChatLLM::handleSystemRecalculate(bool isRecalc)
-{
-#if defined(DEBUG)
-    qDebug() << "system recalc" << m_llmThread.objectName() << isRecalc;
-#endif
-    Q_UNUSED(isRecalc);
-    return false;
-}
-
 bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token)
 {
 #if defined(DEBUG)
@@ -1057,15 +1018,6 @@ bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token)
     return !m_stopGenerating;
 }
 
-bool ChatLLM::handleRestoreStateFromTextRecalculate(bool isRecalc)
-{
-#if defined(DEBUG)
-    qDebug() << "restore state from text recalc" << m_llmThread.objectName() << isRecalc;
-#endif
-    Q_UNUSED(isRecalc);
-    return false;
-}
-
 // this function serialized the cached model state to disk.
 // we want to also serialize n_ctx, and read it at load time.
 bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
@@ -1268,7 +1220,6 @@ void ChatLLM::processSystemPrompt()
     m_ctx = LLModel::PromptContext();
 
     auto promptFunc = std::bind(&ChatLLM::handleSystemPrompt, this, std::placeholders::_1);
-    auto recalcFunc = std::bind(&ChatLLM::handleSystemRecalculate, this, std::placeholders::_1);
 
     const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
     const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
@@ -1294,7 +1245,7 @@ void ChatLLM::processSystemPrompt()
 #endif
     auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode system prompt without a response
     // use "%1%2" and not "%1" to avoid implicit whitespace
-    m_llModelInfo.model->prompt(systemPrompt, "%1%2", promptFunc, nullptr, recalcFunc, m_ctx, true);
+    m_llModelInfo.model->prompt(systemPrompt, "%1%2", promptFunc, nullptr, /*allowContextShift*/ true, m_ctx, true);
     m_ctx.n_predict = old_n_predict;
 #if defined(DEBUG)
     printf("\n");
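As in promptInternal above, the system prompt is decoded with n_predict temporarily forced to 0, so the model ingests the text without producing any response tokens, and the original limit is restored afterwards. std::exchange does the store-and-return-old in one step; a tiny self-contained illustration (the values are made up):

    #include <utility>

    int n_predict = 4096;                            // real generation limit
    int old_n_predict = std::exchange(n_predict, 0); // n_predict == 0, old_n_predict == 4096
    // ... decode the prompt: with n_predict == 0, no response is generated ...
    n_predict = old_n_predict;                       // restore before generating the real response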
@@ -1311,14 +1262,13 @@ void ChatLLM::processRestoreStateFromText()
     if (!isModelLoaded() || !m_restoreStateFromText || m_isServer)
         return;
 
-    m_isRecalc = true;
-    emit recalcChanged();
+    m_restoringFromText = true;
+    emit restoringFromTextChanged();
 
     m_stopGenerating = false;
     m_ctx = LLModel::PromptContext();
 
     auto promptFunc = std::bind(&ChatLLM::handleRestoreStateFromTextPrompt, this, std::placeholders::_1);
-    auto recalcFunc = std::bind(&ChatLLM::handleRestoreStateFromTextRecalculate, this, std::placeholders::_1);
 
     const QString promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
     const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
@@ -1351,7 +1301,7 @@ void ChatLLM::processRestoreStateFromText()
         auto responseText = response.second.toStdString();
 
         m_llModelInfo.model->prompt(prompt.second.toStdString(), promptTemplate.toStdString(), promptFunc, nullptr,
-            recalcFunc, m_ctx, false, &responseText);
+            /*allowContextShift*/ true, m_ctx, false, &responseText);
     }
 
     if (!m_stopGenerating) {
@@ -1359,8 +1309,8 @@ void ChatLLM::processRestoreStateFromText()
         m_stateFromText.clear();
     }
 
-    m_isRecalc = false;
-    emit recalcChanged();
+    m_restoringFromText = false;
+    emit restoringFromTextChanged();
 
     m_pristineLoadedState = false;
 }