chat: faster KV shift, continue generating, fix stop sequences (#2781)

* Don't stop generating at end of context
* Use llama_kv_cache ops to shift context (see the sketch below)
* Fix and improve reverse prompt detection
* Replace prompt recalc callback with a flag to disallow context shift
Author: Jared Van Bortel
Date: 2024-08-07 11:25:24 -04:00
Committed by: GitHub
Parent: 90de2d32f8
Commit: be66ec8ab5

16 changed files with 285 additions and 230 deletions
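The "faster KV shift" in the title refers to the second bullet: instead of clearing the cache and re-decoding ("recalculating") the remaining tokens whenever the context window fills up, the backend now drops the oldest tokens and slides the rest into place with llama.cpp's KV-cache sequence ops. The backend half of the change (llamamodel.cpp) is not among the hunks shown below; what follows is a minimal sketch of the idea, assuming llama.cpp's llama_kv_cache_seq_rm/llama_kv_cache_seq_add API as of mid-2024. Names such as n_keep, n_discard, and contextErase are illustrative, not the project's exact code.

#include "llama.h"

#include <algorithm>
#include <cstdint>
#include <vector>

// Sketch only: make room in a full context window with KV-cache ops, no re-decode.
void shiftContext(llama_context *lctx, std::vector<int32_t> &tokens, int32_t &n_past,
                  int32_t n_ctx, float contextErase /* e.g. 0.5 */)
{
    int32_t n_keep    = 1; // keep the BOS token at position 0
    int32_t n_discard = std::min(n_past - n_keep, int32_t(n_ctx * contextErase));
    if (n_discard <= 0)
        return;

    // Evict the oldest n_discard tokens from the cache...
    llama_kv_cache_seq_rm (lctx, /*seq_id*/ 0, n_keep, n_keep + n_discard);
    // ...then shift the surviving entries' positions down by n_discard, so the
    // model sees a contiguous context and generation can simply continue.
    llama_kv_cache_seq_add(lctx, /*seq_id*/ 0, n_keep + n_discard, n_past, -n_discard);

    tokens.erase(tokens.begin() + n_keep, tokens.begin() + n_keep + n_discard);
    n_past = int32_t(tokens.size());
}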


@@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.16)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
-set(CMAKE_CXX_STANDARD 20)
+set(CMAKE_CXX_STANDARD 23)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
if(APPLE)
@@ -31,7 +31,6 @@ project(gpt4all VERSION ${APP_VERSION_BASE} LANGUAGES CXX C)
set(CMAKE_AUTOMOC ON)
set(CMAKE_AUTORCC ON)
-set(CMAKE_CXX_STANDARD_REQUIRED ON)
option(GPT4ALL_TRANSLATIONS OFF "Build with translations")
option(GPT4ALL_LOCALHOST OFF "Build installer for localhost repo")


@@ -62,7 +62,7 @@ void Chat::connectLLM()
connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::modelLoadingError, this, &Chat::handleModelLoadingError, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::modelLoadingWarning, this, &Chat::modelLoadingWarning, Qt::QueuedConnection);
-connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::handleRecalculating, Qt::QueuedConnection);
+connect(m_llmodel, &ChatLLM::restoringFromTextChanged, this, &Chat::handleRestoringFromText, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::generatedQuestionFinished, this, &Chat::generatedQuestionFinished, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection);
@@ -252,9 +252,9 @@ void Chat::serverNewPromptResponsePair(const QString &prompt)
m_chatModel->appendResponse("Response: ", prompt);
}
-bool Chat::isRecalc() const
+bool Chat::restoringFromText() const
{
-    return m_llmodel->isRecalc();
+    return m_llmodel->restoringFromText();
}
void Chat::unloadAndDeleteLater()
@@ -320,10 +320,10 @@ void Chat::generatedQuestionFinished(const QString &question)
emit generatedQuestionsChanged();
}
-void Chat::handleRecalculating()
+void Chat::handleRestoringFromText()
{
Network::globalInstance()->trackChatEvent("recalc_context", { {"length", m_chatModel->count()} });
-    emit recalcChanged();
+    emit restoringFromTextChanged();
}
void Chat::handleModelLoadingError(const QString &error)


@@ -27,7 +27,7 @@ class Chat : public QObject
Q_PROPERTY(QString response READ response NOTIFY responseChanged)
Q_PROPERTY(ModelInfo modelInfo READ modelInfo WRITE setModelInfo NOTIFY modelInfoChanged)
Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged)
-Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)
+Q_PROPERTY(bool restoringFromText READ restoringFromText NOTIFY restoringFromTextChanged)
Q_PROPERTY(bool isServer READ isServer NOTIFY isServerChanged)
Q_PROPERTY(ResponseState responseState READ responseState NOTIFY responseStateChanged)
Q_PROPERTY(QList<QString> collectionList READ collectionList NOTIFY collectionListChanged)
@@ -88,7 +88,7 @@ public:
ResponseState responseState() const;
ModelInfo modelInfo() const;
void setModelInfo(const ModelInfo &modelInfo);
-bool isRecalc() const;
+bool restoringFromText() const;
Q_INVOKABLE void unloadModel();
Q_INVOKABLE void reloadModel();
@@ -144,7 +144,7 @@ Q_SIGNALS:
void processSystemPromptRequested();
void modelChangeRequested(const ModelInfo &modelInfo);
void modelInfoChanged();
-void recalcChanged();
+void restoringFromTextChanged();
void loadDefaultModelRequested();
void loadModelRequested(const ModelInfo &modelInfo);
void generateNameRequested();
@@ -167,7 +167,7 @@ private Q_SLOTS:
void responseStopped(qint64 promptResponseMs);
void generatedNameChanged(const QString &name);
void generatedQuestionFinished(const QString &question);
-void handleRecalculating();
+void handleRestoringFromText();
void handleModelLoadingError(const QString &error);
void handleTokenSpeedChanged(const QString &tokenSpeed);
void handleDatabaseResultsChanged(const QList<ResultInfo> &results);


@@ -90,13 +90,13 @@ void ChatAPI::prompt(const std::string &prompt,
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
-std::function<bool(bool)> recalculateCallback,
+bool allowContextShift,
PromptContext &promptCtx,
bool special,
std::string *fakeReply) {
Q_UNUSED(promptCallback);
-Q_UNUSED(recalculateCallback);
+Q_UNUSED(allowContextShift);
Q_UNUSED(special);
if (!isModelLoaded()) {
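(Note: ChatAPI talks to a remote API endpoint rather than running a local model, so it has no KV cache to shift; the new allowContextShift flag, like the recalculate callback it replaces, is accepted only to satisfy the LLModel interface and is immediately discarded with Q_UNUSED.)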


@@ -69,7 +69,7 @@ public:
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
-std::function<bool(bool)> recalculateCallback,
+bool allowContextShift,
PromptContext &ctx,
bool special,
std::string *fakeReply) override;
@@ -97,38 +97,57 @@ protected:
// them as they are only called from the default implementation of 'prompt' which we override and
// completely replace
-std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) override {
+std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) override
+{
+    (void)ctx;
+    (void)str;
+    (void)special;
    throw std::logic_error("not implemented");
}
-std::string tokenToString(Token id) const override {
+bool isSpecialToken(Token id) const override
+{
+    (void)id;
+    throw std::logic_error("not implemented");
+}
-Token sampleToken(PromptContext &ctx) const override {
+std::string tokenToString(Token id) const override
+{
+    (void)id;
    throw std::logic_error("not implemented");
}
+Token sampleToken(PromptContext &ctx) const override
+{
+    (void)ctx;
    throw std::logic_error("not implemented");
}
-bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override {
+bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override
+{
+    (void)ctx;
+    (void)tokens;
    throw std::logic_error("not implemented");
}
-int32_t contextLength() const override {
+void shiftContext(PromptContext &promptCtx) override
+{
+    (void)promptCtx;
+    throw std::logic_error("not implemented");
+}
-const std::vector<Token> &endTokens() const override {
+int32_t contextLength() const override
+{
    throw std::logic_error("not implemented");
}
-bool shouldAddBOS() const override {
+const std::vector<Token> &endTokens() const override
+{
    throw std::logic_error("not implemented");
}
+bool shouldAddBOS() const override
+{
+    throw std::logic_error("not implemented");
+}
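The stubs above also document the interface change: LLModel grows a shiftContext() virtual, and prompt() takes a plain allowContextShift flag where it used to take a recalculate callback. With the KV-cache ops there is no lengthy recalculation to report progress on, so callers only need to say whether shifting is permitted at all. A hypothetical sketch of how the shared decode loop can consult the two (decodeBatch is an illustrative name; the real logic in the LLModel base class differs in detail):

// Illustrative only -- not the actual LLModel implementation.
bool LLModel::decodeBatch(PromptContext &ctx, const std::vector<int32_t> &batch,
                          bool allowContextShift)
{
    if (ctx.n_past + int32_t(batch.size()) > contextLength()) {
        if (!allowContextShift)
            return false;    // caller forbade evicting tokens; stop generating instead
        shiftContext(ctx);   // cheap KV-cache shift in place of a full re-decode
    }
    return evalTokens(ctx, batch);
}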


@@ -102,7 +102,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
: QObject{nullptr}
, m_promptResponseTokens(0)
, m_promptTokens(0)
-, m_isRecalc(false)
+, m_restoringFromText(false)
, m_shouldBeLoaded(false)
, m_forceUnloadModel(false)
, m_markedForDeletion(false)
@@ -712,17 +712,6 @@ bool ChatLLM::handleResponse(int32_t token, const std::string &response)
return !m_stopGenerating;
}
-bool ChatLLM::handleRecalculate(bool isRecalc)
-{
-#if defined(DEBUG)
-    qDebug() << "recalculate" << m_llmThread.objectName() << isRecalc;
-#endif
-    if (m_isRecalc != isRecalc) {
-        m_isRecalc = isRecalc;
-        emit recalcChanged();
-    }
-    return !m_stopGenerating;
-}
bool ChatLLM::prompt(const QList<QString> &collectionList, const QString &prompt)
{
if (m_restoreStateFromText) {
@@ -776,7 +765,6 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
auto promptFunc = std::bind(&ChatLLM::handlePrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&ChatLLM::handleResponse, this, std::placeholders::_1,
std::placeholders::_2);
-auto recalcFunc = std::bind(&ChatLLM::handleRecalculate, this, std::placeholders::_1);
emit promptProcessing();
m_ctx.n_predict = n_predict;
m_ctx.top_k = top_k;
@@ -796,10 +784,12 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
m_timer->start();
if (!docsContext.isEmpty()) {
auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode localdocs context without a response
-m_llModelInfo.model->prompt(docsContext.toStdString(), "%1", promptFunc, responseFunc, recalcFunc, m_ctx);
+m_llModelInfo.model->prompt(docsContext.toStdString(), "%1", promptFunc, responseFunc,
+                            /*allowContextShift*/ true, m_ctx);
m_ctx.n_predict = old_n_predict; // now we are ready for a response
}
-m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
+m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc,
+                            /*allowContextShift*/ true, m_ctx);
#if defined(DEBUG)
printf("\n");
fflush(stdout);
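(A note on the std::exchange dance above: setting m_ctx.n_predict to 0 makes prompt() decode the localdocs context into the KV cache without sampling any response tokens; the saved limit is swapped back in before the user's actual prompt runs.)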
@@ -904,10 +894,9 @@ void ChatLLM::generateName()
auto promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
auto promptFunc = std::bind(&ChatLLM::handleNamePrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&ChatLLM::handleNameResponse, this, std::placeholders::_1, std::placeholders::_2);
-auto recalcFunc = std::bind(&ChatLLM::handleNameRecalculate, this, std::placeholders::_1);
LLModel::PromptContext ctx = m_ctx;
m_llModelInfo.model->prompt(chatNamePrompt.toStdString(), promptTemplate.toStdString(),
-    promptFunc, responseFunc, recalcFunc, ctx);
+    promptFunc, responseFunc, /*allowContextShift*/ false, ctx);
std::string trimmed = trim_whitespace(m_nameResponse);
if (trimmed != m_nameResponse) {
m_nameResponse = trimmed;
@@ -944,15 +933,6 @@ bool ChatLLM::handleNameResponse(int32_t token, const std::string &response)
return words.size() <= 3;
}
-bool ChatLLM::handleNameRecalculate(bool isRecalc)
-{
-#if defined(DEBUG)
-    qDebug() << "name recalc" << m_llmThread.objectName() << isRecalc;
-#endif
-    Q_UNUSED(isRecalc);
-    return true;
-}
bool ChatLLM::handleQuestionPrompt(int32_t token)
{
#if defined(DEBUG)
@@ -991,15 +971,6 @@ bool ChatLLM::handleQuestionResponse(int32_t token, const std::string &response)
return true;
}
-bool ChatLLM::handleQuestionRecalculate(bool isRecalc)
-{
-#if defined(DEBUG)
-    qDebug() << "name recalc" << m_llmThread.objectName() << isRecalc;
-#endif
-    Q_UNUSED(isRecalc);
-    return true;
-}
void ChatLLM::generateQuestions(qint64 elapsed)
{
Q_ASSERT(isModelLoaded());
@@ -1019,12 +990,11 @@ void ChatLLM::generateQuestions(qint64 elapsed)
auto promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
auto promptFunc = std::bind(&ChatLLM::handleQuestionPrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&ChatLLM::handleQuestionResponse, this, std::placeholders::_1, std::placeholders::_2);
-auto recalcFunc = std::bind(&ChatLLM::handleQuestionRecalculate, this, std::placeholders::_1);
LLModel::PromptContext ctx = m_ctx;
QElapsedTimer totalTime;
totalTime.start();
-m_llModelInfo.model->prompt(suggestedFollowUpPrompt,
-    promptTemplate.toStdString(), promptFunc, responseFunc, recalcFunc, ctx);
+m_llModelInfo.model->prompt(suggestedFollowUpPrompt, promptTemplate.toStdString(), promptFunc, responseFunc,
+    /*allowContextShift*/ false, ctx);
elapsed += totalTime.elapsed();
emit responseStopped(elapsed);
}
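Name and follow-up-question generation are the two call sites that pass /*allowContextShift*/ false (the old code wired them to recalculate handlers that did nothing but log). They run against a throwaway copy of the prompt context while sharing the model's KV cache, so, presumably, letting them evict tokens would silently invalidate the state the main conversation still depends on; if their input doesn't fit, they simply stop.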
@@ -1039,15 +1009,6 @@ bool ChatLLM::handleSystemPrompt(int32_t token)
return !m_stopGenerating;
}
-bool ChatLLM::handleSystemRecalculate(bool isRecalc)
-{
-#if defined(DEBUG)
-    qDebug() << "system recalc" << m_llmThread.objectName() << isRecalc;
-#endif
-    Q_UNUSED(isRecalc);
-    return false;
-}
bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token)
{
#if defined(DEBUG)
@@ -1057,15 +1018,6 @@ bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token)
return !m_stopGenerating;
}
-bool ChatLLM::handleRestoreStateFromTextRecalculate(bool isRecalc)
-{
-#if defined(DEBUG)
-    qDebug() << "restore state from text recalc" << m_llmThread.objectName() << isRecalc;
-#endif
-    Q_UNUSED(isRecalc);
-    return false;
-}
// this function serialized the cached model state to disk.
// we want to also serialize n_ctx, and read it at load time.
bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
@@ -1268,7 +1220,6 @@ void ChatLLM::processSystemPrompt()
m_ctx = LLModel::PromptContext();
auto promptFunc = std::bind(&ChatLLM::handleSystemPrompt, this, std::placeholders::_1);
-auto recalcFunc = std::bind(&ChatLLM::handleSystemRecalculate, this, std::placeholders::_1);
const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
@@ -1294,7 +1245,7 @@ void ChatLLM::processSystemPrompt()
#endif
auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode system prompt without a response
// use "%1%2" and not "%1" to avoid implicit whitespace
-m_llModelInfo.model->prompt(systemPrompt, "%1%2", promptFunc, nullptr, recalcFunc, m_ctx, true);
+m_llModelInfo.model->prompt(systemPrompt, "%1%2", promptFunc, nullptr, /*allowContextShift*/ true, m_ctx, true);
m_ctx.n_predict = old_n_predict;
#if defined(DEBUG)
printf("\n");
@@ -1311,14 +1262,13 @@ void ChatLLM::processRestoreStateFromText()
if (!isModelLoaded() || !m_restoreStateFromText || m_isServer)
return;
-m_isRecalc = true;
-emit recalcChanged();
+m_restoringFromText = true;
+emit restoringFromTextChanged();
m_stopGenerating = false;
m_ctx = LLModel::PromptContext();
auto promptFunc = std::bind(&ChatLLM::handleRestoreStateFromTextPrompt, this, std::placeholders::_1);
-auto recalcFunc = std::bind(&ChatLLM::handleRestoreStateFromTextRecalculate, this, std::placeholders::_1);
const QString promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
@@ -1351,7 +1301,7 @@ void ChatLLM::processRestoreStateFromText()
auto responseText = response.second.toStdString();
m_llModelInfo.model->prompt(prompt.second.toStdString(), promptTemplate.toStdString(), promptFunc, nullptr,
-    recalcFunc, m_ctx, false, &responseText);
+    /*allowContextShift*/ true, m_ctx, false, &responseText);
}
if (!m_stopGenerating) {
@@ -1359,8 +1309,8 @@ void ChatLLM::processRestoreStateFromText()
m_stateFromText.clear();
}
-m_isRecalc = false;
-emit recalcChanged();
+m_restoringFromText = false;
+emit restoringFromTextChanged();
m_pristineLoadedState = false;
}
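The restore path above also shows how the new flag interacts with the existing fakeReply mechanism: each saved prompt/response pair is replayed through prompt() with the stored response supplied as fakeReply, so the KV cache is rebuilt by decoding rather than by sampling. Roughly, as a hedged sketch (savedTurns, model, promptFunc, and promptTemplate stand in for the members ChatLLM actually uses):

// Sketch of rebuilding the KV cache from serialized chat text (illustrative names).
LLModel::PromptContext ctx;
for (const auto &[userPrompt, reply] : savedTurns) {
    std::string fakeReply = reply;
    // With a non-null fakeReply, prompt() feeds the stored response back through
    // the model instead of sampling new tokens -- nothing is generated on restore.
    model->prompt(userPrompt, promptTemplate, promptFunc, /*responseFunc*/ nullptr,
                  /*allowContextShift*/ true, ctx, /*special*/ false, &fakeReply);
}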


@@ -93,7 +93,7 @@ class Chat;
class ChatLLM : public QObject
{
Q_OBJECT
-Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)
+Q_PROPERTY(bool restoringFromText READ restoringFromText NOTIFY restoringFromTextChanged)
Q_PROPERTY(QString deviceBackend READ deviceBackend NOTIFY loadedModelInfoChanged)
Q_PROPERTY(QString device READ device NOTIFY loadedModelInfoChanged)
Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY loadedModelInfoChanged)
@@ -121,7 +121,7 @@ public:
ModelInfo modelInfo() const;
void setModelInfo(const ModelInfo &info);
-bool isRecalc() const { return m_isRecalc; }
+bool restoringFromText() const { return m_restoringFromText; }
void acquireModel();
void resetModel();
@@ -172,7 +172,7 @@ public Q_SLOTS:
void processRestoreStateFromText();
Q_SIGNALS:
-void recalcChanged();
+void restoringFromTextChanged();
void loadedModelInfoChanged();
void modelLoadingPercentageChanged(float);
void modelLoadingError(const QString &error);
@@ -201,19 +201,14 @@ protected:
int32_t repeat_penalty_tokens);
bool handlePrompt(int32_t token);
bool handleResponse(int32_t token, const std::string &response);
-bool handleRecalculate(bool isRecalc);
bool handleNamePrompt(int32_t token);
bool handleNameResponse(int32_t token, const std::string &response);
-bool handleNameRecalculate(bool isRecalc);
bool handleSystemPrompt(int32_t token);
bool handleSystemResponse(int32_t token, const std::string &response);
-bool handleSystemRecalculate(bool isRecalc);
bool handleRestoreStateFromTextPrompt(int32_t token);
bool handleRestoreStateFromTextResponse(int32_t token, const std::string &response);
-bool handleRestoreStateFromTextRecalculate(bool isRecalc);
bool handleQuestionPrompt(int32_t token);
bool handleQuestionResponse(int32_t token, const std::string &response);
-bool handleQuestionRecalculate(bool isRecalc);
void saveState();
void restoreState();
@@ -236,7 +231,7 @@ private:
QThread m_llmThread;
std::atomic<bool> m_stopGenerating;
std::atomic<bool> m_shouldBeLoaded;
-std::atomic<bool> m_isRecalc;
+std::atomic<bool> m_restoringFromText; // status indication
std::atomic<bool> m_forceUnloadModel;
std::atomic<bool> m_markedForDeletion;
bool m_isServer;


@@ -834,7 +834,7 @@ Rectangle {
to: 360
duration: 1000
loops: Animation.Infinite
-running: currentResponse && (currentChat.responseInProgress || currentChat.isRecalc)
+running: currentResponse && (currentChat.responseInProgress || currentChat.restoringFromText)
}
}
}
@@ -867,13 +867,13 @@ Rectangle {
color: theme.mutedTextColor
}
RowLayout {
-visible: currentResponse && ((value === "" && currentChat.responseInProgress) || currentChat.isRecalc)
+visible: currentResponse && ((value === "" && currentChat.responseInProgress) || currentChat.restoringFromText)
Text {
color: theme.mutedTextColor
font.pixelSize: theme.fontSizeLarger
text: {
-if (currentChat.isRecalc)
-    return qsTr("recalculating context ...");
+if (currentChat.restoringFromText)
+    return qsTr("restoring from text ...");
switch (currentChat.responseState) {
case Chat.ResponseStopped: return qsTr("response stopped ...");
case Chat.LocalDocsRetrieval: return qsTr("retrieving localdocs: %1 ...").arg(currentChat.collectionList.join(", "));
@@ -1861,7 +1861,7 @@ Rectangle {
}
}
function sendMessage() {
-if (textInput.text === "" || currentChat.responseInProgress || currentChat.isRecalc)
+if (textInput.text === "" || currentChat.responseInProgress || currentChat.restoringFromText)
return
currentChat.stopGenerating()